[SPARK-14007] [SQL] Manage the memory used by hash map in shuffled hash join

## What changes were proposed in this pull request?

This PR tries to acquire the memory for the hash map in shuffled hash join, and fails the task if there is not enough memory (otherwise it could OOM the executor).

It also removes the now-unused HashedRelation implementations.
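
At a high level, the build side now reserves on-heap execution memory from the task's TaskMemoryManager while the build rows are being copied for the hash map, and fails fast if the reservation cannot be satisfied. Below is a condensed sketch of that pattern, simplified from the `ShuffledHashJoin.buildHashedRelation` change in this diff; the 150-byte per-entry overhead is a rough estimate, not a measured value.

```scala
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// Sketch: reserve execution memory as build-side rows are copied for the hash map,
// and fail the task early (instead of OOMing the executor) if the grant is too small.
def copyWithMemoryReservation(iter: Iterator[UnsafeRow]): Iterator[UnsafeRow] = {
  val context = TaskContext.get()
  val memoryManager = context.taskMemoryManager()
  var acquired = 0L
  var used = 0L
  // Give everything back when the task completes, whatever happens.
  context.addTaskCompletionListener((t: TaskContext) =>
    memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null))

  iter.map { row =>
    // Rough per-entry cost: two UnsafeRows, a CompactBuffer, and some pointers.
    val needed = 150 + row.getSizeInBytes
    if (needed > acquired - used) {
      val got = memoryManager.acquireExecutionMemory(
        Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
      if (got < needed) {
        throw new SparkException("Can't acquire enough memory to build hash map " +
          "in shuffled hash join")
      }
      acquired += got
    }
    used += needed
    row.copy() // HashedRelation keeps references, so each row must be a separate object
  }
}
```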

## How was this patch tested?

Existing unit tests. Manual tests with TPCDS Q78.

Author: Davies Liu <[email protected]>

Closes #11826 from davies/cleanup_hash2.
Davies Liu authored and davies committed Mar 21, 2016
1 parent 5d8de16 commit 9b4e15b
Showing 5 changed files with 52 additions and 133 deletions.
@@ -268,8 +268,8 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
// there is not enough memory actually; it means the actual free memory is smaller than what
// the MemoryManager thought, so we should keep the acquired memory.
acquiredButNotUsed += acquired;
synchronized (this) {
acquiredButNotUsed += acquired;
allocatedPages.clear(pageNumber);
}
// this could trigger spilling to free some pages.
@@ -100,13 +100,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
* of the join will be broadcasted and the other side will be streamed, with no shuffling
* performed. If both sides of the join are eligible to be broadcasted then the
* - Shuffle hash join: if single partition is small enough to build a hash table.
* - Shuffle hash join: if the average size of a single partition is small enough to build a hash
* table.
* - Sort merge: if the matching join keys are sortable.
*/
object EquiJoinSelection extends Strategy with PredicateHelper {

/**
* Matches a plan whose single partition should be small enough to build a hash table.
*
* Note: this assumes that the number of partitions is fixed; additional work is required if it's
* dynamic.
*/
def canBuildHashMap(plan: LogicalPlan): Boolean = {
plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
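To get a sense of what the threshold above admits, here is a hypothetical sizing exercise; it is illustrative only and assumes the usual defaults of 10 MB for `spark.sql.autoBroadcastJoinThreshold` and 200 for `spark.sql.shuffle.partitions`, which are not part of this patch.

```scala
// Hypothetical defaults, for illustration only.
val autoBroadcastJoinThreshold = 10L * 1024 * 1024 // 10 MB
val numShufflePartitions = 200
val maxBuildSideBytes = autoBroadcastJoinThreshold * numShufflePartitions
// maxBuildSideBytes == 2097152000 (~2 GB): a plan below this total size has an
// average partition of at most ~10 MB, small enough to hash per partition.
```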
@@ -109,51 +109,6 @@ private[execution] trait UniqueHashedRelation extends HashedRelation {
}
}

/**
* A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
*/
private[joins] class GeneralHashedRelation(
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {

// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)

override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)

override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
}

override def readExternal(in: ObjectInput): Unit = {
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
}
}


/**
* A specialized [[HashedRelation]] that maps key into a single value. This implementation
* assumes the key is unique.
*/
private[joins] class UniqueKeyHashedRelation(
private var hashTable: JavaHashMap[InternalRow, InternalRow])
extends UniqueHashedRelation with Externalizable {

// Needed for serialization (it is public to make Java serialization work)
def this() = this(null)

override def getValue(key: InternalRow): InternalRow = hashTable.get(key)

override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
}

override def readExternal(in: ObjectInput): Unit = {
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
}
}


private[execution] object HashedRelation {

/**
@@ -162,51 +117,16 @@ private[execution] object HashedRelation {
* Note: The caller should make sure that these InternalRow are different objects.
*/
def apply(
canJoinKeyFitWithinLong: Boolean,
input: Iterator[InternalRow],
keyGenerator: Projection,
sizeEstimate: Int = 64): HashedRelation = {

if (keyGenerator.isInstanceOf[UnsafeProjection]) {
return UnsafeHashedRelation(
input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
}

// TODO: Use Spark's HashMap implementation.
val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
var currentRow: InternalRow = null

// Whether the join key is unique. If the key is unique, we can convert the underlying
// hash map into one specialized for this.
var keyIsUnique = true

// Create a mapping of buildKeys -> rows
while (input.hasNext) {
currentRow = input.next()
val rowKey = keyGenerator(currentRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new CompactBuffer[InternalRow]()
hashTable.put(rowKey.copy(), newMatchList)
newMatchList
} else {
keyIsUnique = false
existingMatchList
}
matchList += currentRow
}
}

if (keyIsUnique) {
val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size)
val iter = hashTable.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
uniqHashTable.put(entry.getKey, entry.getValue()(0))
}
new UniqueKeyHashedRelation(uniqHashTable)
if (canJoinKeyFitWithinLong) {
LongHashedRelation(input, keyGenerator, sizeEstimate)
} else {
new GeneralHashedRelation(hashTable)
UnsafeHashedRelation(
input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
}
}
}
@@ -428,6 +348,7 @@ private[joins] object UnsafeHashedRelation {
sizeEstimate: Int): HashedRelation = {

// Use a Java hash table here because unsafe maps expect fixed size records
// TODO: Use BytesToBytesMap for memory efficiency
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)

// Create a mapping of buildKeys -> rows
@@ -683,11 +604,7 @@ private[execution] case class HashedRelationBroadcastMode(

override def transform(rows: Array[InternalRow]): HashedRelation = {
val generator = UnsafeProjection.create(keys, attributes)
if (canJoinKeyFitWithinLong) {
LongHashedRelation(rows.iterator, generator, rows.length)
} else {
HashedRelation(rows.iterator, generator, rows.length)
}
HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
}

private lazy val canonicalizedKeys: Seq[Expression] = {
@@ -703,4 +620,3 @@ private[execution] case class HashedRelationBroadcastMode(
case _ => false
}
}

@@ -17,17 +17,18 @@

package org.apache.spark.sql.execution.joins

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.memory.MemoryMode
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
* Performs an inner hash join of two child relations by first shuffling the data using the join
* keys.
* Performs a hash join of two child relations by first shuffling the data using the join keys.
*/
case class ShuffledHashJoin(
leftKeys: Seq[Expression],
@@ -55,11 +56,45 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
// Try to acquire some memory for the hash table; it could trigger other operators to free some
// memory. The memory acquired here will mostly be held until the end of the task.
val context = TaskContext.get()
val memoryManager = context.taskMemoryManager()
var acquired = 0L
var used = 0L
context.addTaskCompletionListener((t: TaskContext) =>
memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
)

val copiedIter = iter.map { row =>
// It's hard to guess exactly how much memory will be used, so we make a rough estimate here.
// TODO: use BytesToBytesMap instead of HashMap for memory efficiency
// Each pair in the HashMap will have two UnsafeRows, one CompactBuffer, and maybe 10+ pointers
val needed = 150 + row.getSizeInBytes
if (needed > acquired - used) {
val got = memoryManager.acquireExecutionMemory(
Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
if (got < needed) {
throw new SparkException("Can't acquire enough memory to build hash map in shuffled " +
"hash join, please use sort merge join by setting " +
"spark.sql.join.preferSortMergeJoin=true")
}
acquired += got
}
used += needed
// HashedRelation requires that the UnsafeRows be separate objects.
row.copy()
}

HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
}

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
val joinedRow = new JoinedRow
joinType match {
case Inner =>
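If the reservation above fails, the exception directs users to sort merge join. A minimal way to opt into it, assuming an in-scope `sqlContext` and the SQLContext configuration API of this Spark version:

```scala
// Prefer sort merge join over shuffled hash join for subsequent queries.
sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "true")
// Equivalent from SQL: SET spark.sql.join.preferSortMergeJoin=true
```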
@@ -29,42 +29,6 @@ import org.apache.spark.util.collection.CompactBuffer

class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {

// Key is simply the record itself
private val keyProjection = new Projection {
override def apply(row: InternalRow): InternalRow = row
}

test("GeneralHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])

assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(InternalRow(10)) === null)

val data2 = CompactBuffer[InternalRow](data(2))
data2 += data(2)
assert(hashed.get(data(2)) === data2)
}

test("UniqueKeyHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])

assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
assert(hashed.get(InternalRow(10)) === null)

val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
assert(uniqHashed.getValue(data(0)) === data(0))
assert(uniqHashed.getValue(data(1)) === data(1))
assert(uniqHashed.getValue(data(2)) === data(2))
assert(uniqHashed.getValue(InternalRow(10)) === null)
}

test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
