[SPARK-19256][SQL] Hive bucketing support #19001

Closed · wants to merge 6 commits
18 changes: 18 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -123,6 +123,24 @@ private[spark] class CoalescedRDD[T: ClassTag](
partition.asInstanceOf[CoalescedRDDPartition].preferredLocation.toSeq
}
}
/**
* A [[PartitionCoalescer]] that assigns partitions of the parent RDD to output partition
* groups in round-robin fashion: the i-th partition of the parent RDD is mapped to the
* (i % targetPartitions)-th partition group of the output.
*/
private[spark] class RoundRobinPartitionCoalescer() extends PartitionCoalescer with Serializable {
def coalesce(targetPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
val partitionGroups = ArrayBuffer[PartitionGroup]()
for (_ <- 0 until targetPartitions) {
partitionGroups += new PartitionGroup(None)
}

for ((p, i) <- parent.partitions.zipWithIndex) {
partitionGroups(i % targetPartitions).partitions += p
}
partitionGroups.toArray
}
}
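
Not part of this diff: a minimal sketch of how a PartitionCoalescer like the one above could be plugged into RDD.coalesce, which accepts an optional coalescer. The parent RDD, partition counts, and the SparkContext `sc` are assumed for illustration; the class above is private[spark], so a real call site would have to live inside that package.

import org.apache.spark.rdd.RDD

// Hypothetical parent RDD with 10 partitions (sc is an existing SparkContext).
val parent: RDD[Int] = sc.parallelize(1 to 100, numSlices = 10)

// Collapse to 4 output partitions without a shuffle: parent partition i ends up
// in output partition i % 4, exactly as the coalescer above assigns it.
val coalesced = parent.coalesce(
  numPartitions = 4,
  shuffle = false,
  partitionCoalescer = Some(new RoundRobinPartitionCoalescer()))

assert(coalesced.getNumPartitions == 4)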

/**
* Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.physical

import scala.language.existentials

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -83,6 +85,11 @@ case class ClusteredDistribution(
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")

require(
requiredNumPartitions.isEmpty || requiredNumPartitions.get >= 0,
"If the required number of partitions is defined for ClusteredDistribution, it should be a " +
"non-negative number, but " + requiredNumPartitions.get + " was provided.")

override def createPartitioning(numPartitions: Int): Partitioning = {
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
@@ -99,17 +106,28 @@ case class ClusteredDistribution(
* This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
* number of partitions, this distribution strictly requires which partition the tuple should be in.
*/
case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution {
case class HashClusteredDistribution(
expressions: Seq[Expression],
requiredNumPartitions: Option[Int] = None,
hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash])
extends Distribution {

require(
expressions != Nil,
"The expressions for hash of a HashPartitionedDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")

override def requiredNumPartitions: Option[Int] = None
require(
requiredNumPartitions.isEmpty || requiredNumPartitions.get >= 0,
"If the required number of partitions is defined for HashClusteredDistribution, it should " +
"be a non-negative number, but " + requiredNumPartitions.get + " was provided.")

override def createPartitioning(numPartitions: Int): Partitioning = {
HashPartitioning(expressions, numPartitions)
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
s"the actual number of partitions is $numPartitions.")
HashPartitioning(expressions, numPartitions, hashingFunctionClass)
}
}

@@ -198,7 +216,10 @@ case object SinglePartition extends Partitioning {
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition.
*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case class HashPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash])
extends Expression with Partitioning with Unevaluable {

override def children: Seq[Expression] = expressions
@@ -209,9 +230,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
super.satisfies(required) || {
required match {
case h: HashClusteredDistribution =>
expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
h.hashingFunctionClass == hashingFunctionClass &&
(h.requiredNumPartitions.isEmpty || h.requiredNumPartitions.get == numPartitions) &&
expressions.length == h.expressions.length &&
expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) }

case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
@@ -222,9 +245,16 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

/**
* Returns an expression that will produce a valid partition ID (i.e. non-negative and less
* than numPartitions) based on hashing expressions.
* than numPartitions) based on hashing expression(s) and the hashing function.
*/
def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
def partitionIdExpression: Expression = {
val hashExpression = hashingFunctionClass match {
case m if m == classOf[Murmur3Hash] => new Murmur3Hash(expressions)
case h if h == classOf[HiveHash] => HiveHash(expressions)
case _ => throw new IllegalArgumentException(
s"Unsupported hashing function class: $hashingFunctionClass")
}
Pmod(hashExpression, Literal(numPartitions))
}
}
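
Not part of this diff: a short sketch of what the new hashingFunctionClass parameter does to the bucket-id expression. With the default Murmur3Hash, partitionIdExpression yields the expression Spark's native bucketing already uses; passing classOf[HiveHash] yields a Hive-compatible bucket id instead. The column name and bucket count below are hypothetical.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, HiveHash}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

// Hypothetical bucketing column.
val userId = AttributeReference("user_id", IntegerType)()

// Spark-native bucket id: Pmod(Murmur3Hash(Seq(user_id)), 8).
val sparkBucketId = HashPartitioning(Seq(userId), 8).partitionIdExpression

// Hive-compatible bucket id: Pmod(HiveHash(Seq(user_id)), 8), so rows land in the
// same buckets that Hive itself would assign them to.
val hiveBucketId = HashPartitioning(Seq(userId), 8, classOf[HiveHash]).partitionIdExpression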

@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.SparkFunSuite
/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{HiveHash, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical._

class DistributionSuite extends SparkFunSuite {
@@ -79,6 +80,26 @@ class DistributionSuite extends SparkFunSuite {
ClusteredDistribution(Seq('d, 'e)),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
HashClusteredDistribution(Seq('a, 'b, 'c), Some(10), classOf[Murmur3Hash]),
true)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
HashClusteredDistribution(Seq('a, 'b, 'c), Some(12), classOf[Murmur3Hash]),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
HashClusteredDistribution(Seq('d, 'e), Some(10), classOf[Murmur3Hash]),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
HashClusteredDistribution(Seq('a, 'b, 'c), Some(10), classOf[HiveHash]),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
AllTuples,
@@ -125,21 +146,6 @@
OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('a, 'b, 'c)),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('c, 'b, 'a)),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('b, 'c, 'a, 'd)),
true)

// Cases which need an exchange between two data properties.
// TODO: We can have an optimization to first sort the dataset
// by a.asc and then sort b, and c in a partition. This optimization
@@ -21,8 +21,9 @@ import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.datasources.FileFormatWriter
@@ -60,5 +61,9 @@ trait DataWritingCommand extends Command {
new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
}

def requiredDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution)

def requiredOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
}
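
Not part of this diff: a rough sketch of the values a Hive-bucketed write command could return from the two hooks above, so that DataWritingCommandExec (via EnsureRequirements) inserts the shuffle and sort rather than the writer doing it manually, as discussed in the review comments further down. The column names and bucket count are hypothetical.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, HiveHash, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
import org.apache.spark.sql.types.{IntegerType, LongType}

// Hypothetical layout: 8 buckets on user_id, each bucket file sorted by ts.
val userId = AttributeReference("user_id", IntegerType)()
val ts = AttributeReference("ts", LongType)()

// requiredDistribution: cluster the single child by HiveHash(user_id) into exactly
// 8 partitions, one partition per bucket.
val requiredDistribution: Seq[Distribution] =
  Seq(HashClusteredDistribution(Seq(userId), Some(8), classOf[HiveHash]))

// requiredOrdering: rows within each partition (bucket) arrive sorted by ts, so the
// writer can emit sorted bucket files without its own SortExec.
val requiredOrdering: Seq[Seq[SortOrder]] = Seq(Seq(SortOrder(ts, Ascending)))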
@@ -23,9 +23,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -43,7 +44,13 @@ trait RunnableCommand extends Command {
// `ExecutedCommand` during query planning.
lazy val metrics: Map[String, SQLMetric] = Map.empty

def run(sparkSession: SparkSession): Seq[Row]
def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
Review comment (Contributor): ExecutedCommandExec doesn't call it.
throw new NotImplementedError
}

def run(sparkSession: SparkSession): Seq[Row] = {
throw new NotImplementedError
}
}

@@ -112,6 +119,10 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)

override def nodeName: String = "Execute " + cmd.nodeName

override def requiredChildDistribution: Seq[Distribution] = cmd.requiredDistribution

override def requiredChildOrdering: Seq[Seq[SortOrder]] = cmd.requiredOrdering

override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray

override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator
@@ -34,12 +34,11 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}

@@ -109,7 +108,7 @@ object FileFormatWriter extends Logging {
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
bucketIdExpression: Option[Expression],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String])
: Set[String] = {
@@ -122,17 +121,6 @@
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains)

val bucketIdExpression = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}

val caseInsensitiveOptions = CaseInsensitiveMap(options)

// Note: prepareWrite has side effect. It sets "job".
@@ -156,40 +144,14 @@
statsTrackers = statsTrackers
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
Review comment (Contributor): Can we send an individual PR to do this? i.e. do the sorting via requiredOrdering instead of doing it manually.

// the sort order doesn't matter
val actualOrdering = plan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
requiredOrdering.zip(actualOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
}
}

SQLExecution.checkSQLExecutionId(sparkSession)

// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

try {
val rdd = if (orderingMatched) {
plan.execute()
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
val orderingExpr = requiredOrdering
.map(SortOrder(_, Ascending))
.map(BindReferences.bindReference(_, outputSpec.outputColumns))
SortExec(
Review comment (Contributor): Removing SortExec here and adding it in the EnsureRequirements strategy will affect many other DataWritingCommands that depend on FileFormatWriter, such as CreateDataSourceTableAsSelectCommand. To fix this, those DataWritingCommand implementations need changes to export requiredDistribution and requiredOrdering.

orderingExpr,
global = false,
child = plan).execute()
}
val rdd = plan.execute()
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
rdd,
Expand All @@ -202,7 +164,7 @@ object FileFormatWriter extends Logging {
committer,
iterator = iter)
},
0 until rdd.partitions.length,
rdd.partitions.indices,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
@@ -521,18 +483,18 @@ object FileFormatWriter extends Logging {
var recordsInFile: Long = 0L
var fileCounter = 0
val updatedPartitions = mutable.Set[String]()
var currentPartionValues: Option[UnsafeRow] = None
var currentPartitionValues: Option[UnsafeRow] = None
var currentBucketId: Option[Int] = None

for (row <- iter) {
val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None

if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) {
// See a new partition or bucket - write to a new partition dir (or a new bucket file).
if (isPartitioned && currentPartionValues != nextPartitionValues) {
currentPartionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartionValues.get))
if (isPartitioned && currentPartitionValues != nextPartitionValues) {
currentPartitionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartitionValues.get))
}
if (isBucketed) {
currentBucketId = nextBucketId
Expand All @@ -543,7 +505,7 @@ object FileFormatWriter extends Logging {
fileCounter = 0

releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
} else if (desc.maxRecordsPerFile > 0 &&
recordsInFile >= desc.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
Expand All @@ -554,7 +516,7 @@ object FileFormatWriter extends Logging {
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
}
val outputRow = getOutputRow(row)
currentWriter.write(outputRow)