diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 10451a324b0f4..871c2a4121d34 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -123,6 +123,24 @@ private[spark] class CoalescedRDD[T: ClassTag]( partition.asInstanceOf[CoalescedRDDPartition].preferredLocation.toSeq } } +/** + * Coalesce the partitions of a parent RDD into fewer partitions, so that each partition of + * this RDD computes one or more of the parent ones. Every i-th partition of the parent RDD is + * mapped to the (i % targetPartitions)-th partition of the output RDD. + */ +private[spark] class RoundRobinPartitionCoalescer() extends PartitionCoalescer with Serializable { + def coalesce(targetPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { + val partitionGroups = ArrayBuffer[PartitionGroup]() + for (_ <- 0 until targetPartitions) { + partitionGroups += new PartitionGroup(None) + } + + for ((p, i) <- parent.partitions.zipWithIndex) { + partitionGroups(i % targetPartitions).partitions += p + } + partitionGroups.toArray + } +} /** * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0189bd73c56bf..e48147f05bf9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import scala.language.existentials + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -83,6 +85,11 @@ case class ClusteredDistribution( "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + require( + requiredNumPartitions.isEmpty || requiredNumPartitions.get >= 0, + "If the required number of partitions is defined for ClusteredDistribution, it should be a " + + "non-negative number, but " + requiredNumPartitions.get + " was provided.") + override def createPartitioning(numPartitions: Int): Partitioning = { assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + @@ -99,17 +106,28 @@ case class ClusteredDistribution( * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { +case class HashClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None, + hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash]) + extends Distribution { + require( expressions != Nil, "The expressions for hash of a HashPartitionedDistribution should not be Nil. 
" + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - override def requiredNumPartitions: Option[Int] = None + require( + requiredNumPartitions.isEmpty || requiredNumPartitions.get >= 0, + "If the required number of partitions is defined for HashClusteredDistribution, it should " + + "be a non-negative number, but " + requiredNumPartitions.get + " was provided.") override def createPartitioning(numPartitions: Int): Partitioning = { - HashPartitioning(expressions, numPartitions) + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") + HashPartitioning(expressions, numPartitions, hashingFunctionClass) } } @@ -198,7 +216,10 @@ case object SinglePartition extends Partitioning { * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be * in the same partition. */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) +case class HashPartitioning( + expressions: Seq[Expression], + numPartitions: Int, + hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash]) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions @@ -209,9 +230,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) super.satisfies(required) || { required match { case h: HashClusteredDistribution => - expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { - case (l, r) => l.semanticEquals(r) - } + h.hashingFunctionClass == hashingFunctionClass && + (h.requiredNumPartitions.isEmpty || h.requiredNumPartitions.get == numPartitions) && + expressions.length == h.expressions.length && + expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) @@ -222,9 +245,16 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) /** * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less - * than numPartitions) based on hashing expressions. + * than numPartitions) based on hashing expression(s) and the hashing function. 
*/ - def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) + def partitionIdExpression: Expression = { + val hashExpression = hashingFunctionClass match { + case m if m == classOf[Murmur3Hash] => new Murmur3Hash(expressions) + case h if h == classOf[HiveHash] => HiveHash(expressions) + case _ => throw new Exception(s"Unsupported hashingFunction: $hashingFunctionClass") + } + Pmod(hashExpression, Literal(numPartitions)) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d55..726d23f09a28b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{HiveHash, Murmur3Hash} import org.apache.spark.sql.catalyst.plans.physical._ class DistributionSuite extends SparkFunSuite { @@ -79,6 +80,26 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10), classOf[Murmur3Hash]), + true) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(12), classOf[Murmur3Hash]), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('d, 'e), Some(10), classOf[Murmur3Hash]), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10), classOf[HiveHash]), + false) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), AllTuples, @@ -125,21 +146,6 @@ class DistributionSuite extends SparkFunSuite { OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), true) - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'b, 'a)), - true) - - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('b, 'c, 'a, 'd)), - true) - // Cases which need an exchange between two data properties. // TODO: We can have an optimization to first sort the dataset // by a.asc and then sort b, and c in a partition. 
This optimization diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e56f8105fc9a7..69bbce086d3d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -21,8 +21,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.datasources.FileFormatWriter @@ -60,5 +61,9 @@ trait DataWritingCommand extends Command { new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) } + def requiredDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + + def requiredOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2cc0e38adc2ee..52f91ab1e66a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -23,9 +23,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan} import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric @@ -43,7 +44,13 @@ trait RunnableCommand extends Command { // `ExecutedCommand` during query planning. 
lazy val metrics: Map[String, SQLMetric] = Map.empty - def run(sparkSession: SparkSession): Seq[Row] + def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + throw new NotImplementedError + } + + def run(sparkSession: SparkSession): Seq[Row] = { + throw new NotImplementedError + } } /** @@ -112,6 +119,10 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan) override def nodeName: String = "Execute " + cmd.nodeName + override def requiredChildDistribution: Seq[Distribution] = cmd.requiredDistribution + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = cmd.requiredOrdering + override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1d80a69bc5a1d..3f1dc1c5198c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -34,12 +34,11 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -109,7 +108,7 @@ object FileFormatWriter extends Logging { outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], - bucketSpec: Option[BucketSpec], + bucketIdExpression: Option[Expression], statsTrackers: Seq[WriteJobStatsTracker], options: Map[String, String]) : Set[String] = { @@ -122,17 +121,6 @@ object FileFormatWriter extends Logging { val partitionSet = AttributeSet(partitionColumns) val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains) - val bucketIdExpression = bucketSpec.map { spec => - val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression - } - val sortColumns = bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) - } - val caseInsensitiveOptions = CaseInsensitiveMap(options) // Note: prepareWrite has side effect. It sets "job". 
@@ -156,19 +144,6 @@ object FileFormatWriter extends Logging { statsTrackers = statsTrackers ) - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns - // the sort order doesn't matter - val actualOrdering = plan.outputOrdering.map(_.child) - val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { - false - } else { - requiredOrdering.zip(actualOrdering).forall { - case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) - } - } - SQLExecution.checkSQLExecutionId(sparkSession) // This call shouldn't be put into the `try` block below because it only initializes and @@ -176,20 +151,7 @@ object FileFormatWriter extends Logging { committer.setupJob(job) try { - val rdd = if (orderingMatched) { - plan.execute() - } else { - // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and - // the physical plan may have different attribute ids due to optimizer removing some - // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. - val orderingExpr = requiredOrdering - .map(SortOrder(_, Ascending)) - .map(BindReferences.bindReference(_, outputSpec.outputColumns)) - SortExec( - orderingExpr, - global = false, - child = plan).execute() - } + val rdd = plan.execute() val ret = new Array[WriteTaskResult](rdd.partitions.length) sparkSession.sparkContext.runJob( rdd, @@ -202,7 +164,7 @@ object FileFormatWriter extends Logging { committer, iterator = iter) }, - 0 until rdd.partitions.length, + rdd.partitions.indices, (index, res: WriteTaskResult) => { committer.onTaskCommit(res.commitMsg) ret(index) = res @@ -521,18 +483,18 @@ object FileFormatWriter extends Logging { var recordsInFile: Long = 0L var fileCounter = 0 val updatedPartitions = mutable.Set[String]() - var currentPartionValues: Option[UnsafeRow] = None + var currentPartitionValues: Option[UnsafeRow] = None var currentBucketId: Option[Int] = None for (row <- iter) { val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None - if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { + if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartionValues != nextPartitionValues) { - currentPartionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartionValues.get)) + if (isPartitioned && currentPartitionValues != nextPartitionValues) { + currentPartitionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) } if (isBucketed) { currentBucketId = nextBucketId @@ -543,7 +505,7 @@ object FileFormatWriter extends Logging { fileCounter = 0 releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) + newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions) } else if (desc.maxRecordsPerFile > 0 && recordsInFile >= desc.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. 
@@ -554,7 +516,7 @@ object FileFormatWriter extends Logging { s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) + newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions) } val outputRow = getOutputRow(row) currentWriter.write(outputRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index dd7ef0d15c140..dc8a49fa8b61d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -25,8 +25,9 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, HiveHash, Murmur3Hash, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -150,6 +151,10 @@ case class InsertIntoHadoopFsRelationCommand( } } + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = query.output.filterNot(partitionSet.contains) + val bucketIdExpression = getBucketIdExpression(dataColumns) + val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, @@ -160,7 +165,7 @@ case class InsertIntoHadoopFsRelationCommand( qualifiedOutputPath.toString, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionColumns, - bucketSpec = bucketSpec, + bucketIdExpression = bucketIdExpression, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), options = options) @@ -184,6 +189,43 @@ case class InsertIntoHadoopFsRelationCommand( Seq.empty[Row] } + private def getBucketIdExpression(dataColumns: Seq[Attribute]): Option[Expression] = { + bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning( + bucketColumns, + spec.numBuckets, + classOf[Murmur3Hash] + ).partitionIdExpression + } + } + + /** + * How is `requiredOrdering` determined ? 
+ * + * table type | requiredOrdering + * -----------------+------------------------------------------------- + * normal table | partition columns + * bucketed table | (partition columns + bucketId + sort columns) + * -----------------+------------------------------------------------- + */ + override def requiredOrdering: Seq[Seq[SortOrder]] = { + val sortExpressions = bucketSpec match { + case Some(spec) => + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = query.output.filterNot(partitionSet.contains) + val bucketIdExpression = getBucketIdExpression(dataColumns) + val sortColumns = spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + partitionColumns ++ bucketIdExpression ++ sortColumns + + case _ => partitionColumns + } + Seq(sortExpressions.map(SortOrder(_, Ascending))) + } + /** * Deletes all partition files that match the specified static prefix. Partitions with custom * locations are also cleared based on the custom locations map given to this class. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c5470..94dbd21578f20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.exchange import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -243,13 +244,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { leftPartitioning match { - case HashPartitioning(leftExpressions, _) + case HashPartitioning(leftExpressions, _, _) if leftExpressions.length == leftKeys.length && leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => reorder(leftKeys, rightKeys, leftExpressions, leftKeys) case _ => rightPartitioning match { - case HashPartitioning(rightExpressions, _) + case HashPartitioning(rightExpressions, _, _) if rightExpressions.length == rightKeys.length && rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => reorder(leftKeys, rightKeys, rightExpressions, rightKeys) @@ -262,14 +263,54 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } } + private def adjustHashingFunctionAndNumPartitions( + plan: SparkPlan, + requiredNumPartitions: Option[Int], + hashingFunctionClass: Class[_ <: HashExpression[Int]]): + (Option[Int], Class[_ <: HashExpression[Int]]) = { + + val childHashPartitionings = plan.children.map(_.outputPartitioning) + .filter(_.isInstanceOf[HashPartitioning]) + .map(_.asInstanceOf[HashPartitioning]) + + val distinctRequiredNumPartitions = childHashPartitionings.map(_.numPartitions).distinct + val newRequiredNumPartitions = + if (distinctRequiredNumPartitions.nonEmpty && distinctRequiredNumPartitions.size == 1) { + Some(distinctRequiredNumPartitions.head) + } else { + requiredNumPartitions + } + + val distinctHashingFunctions = childHashPartitionings.map(_.hashingFunctionClass).distinct + val newHashingFunctionClass = + if (distinctHashingFunctions.nonEmpty && distinctHashingFunctions.size == 1) { + distinctHashingFunctions.head + } else { + 
hashingFunctionClass + } + + (newRequiredNumPartitions, newHashingFunctionClass) + } + /** + * Based on the type of join and the properties of the child nodes, adjust the following: + * + * [A] Join keys + * ----------------------------- * When the physical operators are created for JOIN, the ordering of join keys is based on order * in which the join keys appear in the user query. That might not match with the output * partitioning of the join node's children (thus leading to extra sort / shuffle being - * introduced). This rule will change the ordering of the join keys to match with the + * introduced). This method will change the ordering of the join keys to match with the * partitioning of the join nodes' children. + * + * [B] Hashing function class and required partitions for children + * -------------------------------------------------------------------- + * When the children of the join node are already shuffled using the same hash function and + * have the same number of partitions, let the join node reuse those values (instead of the + * default number of shuffle partitions and hashing function). This avoids re-shuffling the + * join node's children. */ - private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { + private def adjustJoinRequirements(plan: SparkPlan): SparkPlan = { plan.transformUp { case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => @@ -278,16 +319,25 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, left, right) - case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => + case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right, + requiredNumPartitions, hashingFunctionClass) => + val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + val (newRequiredNumPartitions, newHashingFunctionClass) = + adjustHashingFunctionAndNumPartitions(plan, requiredNumPartitions, hashingFunctionClass) ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) + left, right, newRequiredNumPartitions, newHashingFunctionClass) + + case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, + requiredNumPartitions, hashingFunctionClass) => - case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + val (newRequiredNumPartitions, newHashingFunctionClass) = + adjustHashingFunctionAndNumPartitions(plan, requiredNumPartitions, hashingFunctionClass) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, + left, right, newRequiredNumPartitions, newHashingFunctionClass) } } @@ -299,6 +349,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case _ => operator } case operator: SparkPlan => - ensureDistributionAndOrdering(reorderJoinPredicates(operator)) + ensureDistributionAndOrdering(adjustJoinRequirements(operator)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 
5a1e217082bc2..010ba86340db0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -204,7 +204,7 @@ object ShuffleExchangeExec { serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case HashPartitioning(_, n) => + case HashPartitioning(_, n, _) => new Partitioner { override def numPartitions: Int = n // For HashPartitioning, the partitioning key is already a valid partition ID, as we use diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 897a4dae39f32..7ed569f11fd04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.joins +import scala.language.existentials + import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, HashExpression, Murmur3Hash} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} @@ -36,7 +38,9 @@ case class ShuffledHashJoinExec( buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) + right: SparkPlan, + requiredNumPartitions: Option[Int] = None, + hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash]) extends BinaryExecNode with HashJoin { override lazy val metrics = Map( @@ -45,8 +49,11 @@ case class ShuffledHashJoinExec( "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"), "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) - override def requiredChildDistribution: Seq[Distribution] = - HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil + override def requiredChildDistribution: Seq[Distribution] = { + HashClusteredDistribution(leftKeys, requiredNumPartitions, hashingFunctionClass) :: + HashClusteredDistribution(rightKeys, requiredNumPartitions, hashingFunctionClass) :: + Nil + } private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2de2f30eb05d3..aaddb2558aa42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -39,7 +40,10 @@ case class SortMergeJoinExec( joinType: JoinType, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryExecNode with 
CodegenSupport { + right: SparkPlan, + requiredNumPartitions: Option[Int] = None, + hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash]) + extends BinaryExecNode with CodegenSupport { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -77,8 +81,11 @@ case class SortMergeJoinExec( s"${getClass.getSimpleName} should not take $x as the JoinType") } - override def requiredChildDistribution: Seq[Distribution] = - HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil + override def requiredChildDistribution: Seq[Distribution] = { + HashClusteredDistribution(leftKeys, requiredNumPartitions, hashingFunctionClass) :: + HashClusteredDistribution(rightKeys, requiredNumPartitions, hashingFunctionClass) :: + Nil + } override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 2715fa93d0e98..bdc99a4085f0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -128,7 +128,7 @@ class FileStreamSink( outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output), hadoopConf = hadoopConf, partitionColumns = partitionColumns, - bucketSpec = None, + bucketIdExpression = None, statsTrackers = Nil, options = options) } diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/BucketizedSparkInputFormat.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/BucketizedSparkInputFormat.java new file mode 100644 index 0000000000000..51004b15e1738 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/BucketizedSparkInputFormat.java @@ -0,0 +1,107 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.io; + +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapred.*; +import org.apache.hadoop.util.StringUtils; + +import java.io.IOException; +import java.util.Arrays; + +import static org.apache.hadoop.mapreduce.lib.input.FileInputFormat.INPUT_DIR; + +/** + * A {@link InputFormat} implementation for reading bucketed tables. + * + * We cannot directly use {@link BucketizedHiveInputFormat} from Hive as it depends on the + * map-reduce plan to get required information for split generation. 
+ */ +public class BucketizedSparkInputFormat + extends BucketizedHiveInputFormat { + + private static final String FILE_INPUT_FORMAT = "file.inputformat"; + + @Override + public RecordReader getRecordReader( + InputSplit split, + JobConf job, + Reporter reporter) throws IOException { + + BucketizedHiveInputSplit hsplit = (BucketizedHiveInputSplit) split; + String inputFormatClassName = null; + Class inputFormatClass = null; + + try { + inputFormatClassName = hsplit.inputFormatClassName(); + inputFormatClass = job.getClassByName(inputFormatClassName); + } catch (ClassNotFoundException e) { + throw new IOException("Cannot find class " + inputFormatClassName, e); + } + + InputFormat inputFormat = getInputFormatFromCache(inputFormatClass, job); + return new BucketizedSparkRecordReader<>(inputFormat, hsplit, job, reporter); + } + + @Override + public InputSplit[] getSplits(JobConf job, int numBuckets) throws IOException { + final String inputFormatClassName = job.get(FILE_INPUT_FORMAT); + final String[] inputDirs = job.get(INPUT_DIR).split(StringUtils.COMMA_STR); + + if (inputDirs.length != 1) { + throw new IOException(this.getClass().getCanonicalName() + + " expects only one input directory. " + inputDirs.length + + " directories detected : " + Arrays.toString(inputDirs)); + } + + final String inputDir = inputDirs[0]; + final Path inputPath = new Path(inputDir); + final JobConf newJob = new JobConf(job); + final FileStatus[] listStatus = this.listStatus(newJob, inputPath); + final InputSplit[] result = new InputSplit[numBuckets]; + + if (listStatus.length != 0 && listStatus.length != numBuckets) { + throw new IOException("Bucketed path was expected to have " + numBuckets + " files but " + + listStatus.length + " files are present. Path = " + inputPath); + } + + try { + final Class inputFormatClass = Class.forName(inputFormatClassName); + final InputFormat inputFormat = getInputFormatFromCache(inputFormatClass, job); + newJob.setInputFormat(inputFormat.getClass()); + + for (int i = 0; i < numBuckets; i++) { + final FileStatus fileStatus = listStatus[i]; + FileInputFormat.setInputPaths(newJob, fileStatus.getPath()); + + final InputSplit[] inputSplits = inputFormat.getSplits(newJob, 0); + if (inputSplits != null && inputSplits.length > 0) { + result[i] = + new BucketizedHiveInputSplit(inputSplits, inputFormatClass.getName()); + } + } + } catch (ClassNotFoundException e) { + throw new IOException("Unable to find the InputFormat class " + inputFormatClassName, e); + } + return result; + } +} diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/BucketizedSparkRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/BucketizedSparkRecordReader.java new file mode 100644 index 0000000000000..7cf0fede87678 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/BucketizedSparkRecordReader.java @@ -0,0 +1,147 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.io; + +import org.apache.hadoop.hive.io.HiveIOExceptionHandlerUtil; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapred.*; + +import java.io.IOException; + +/** + * A {@link RecordReader} implementation for reading {@link BucketizedHiveInputSplit}. + * Each {@link BucketizedHiveInputSplit} packs multiple {@link InputSplit}s that together + * correspond to a single bucket. This record reader opens the record readers for those + * splits one by one (transparently to the caller) and returns their records. + * + * The Hive counterpart for this class is {@link BucketizedHiveRecordReader}. The reason for not + * re-using Hive's {@link BucketizedHiveRecordReader} is that it relies on the map-reduce plan, + * which is not available when running jobs in Spark. + */ +public class BucketizedSparkRecordReader + implements RecordReader { + + protected final BucketizedHiveInputSplit split; + protected final InputFormat inputFormat; + + private final Reporter reporter; + private long progress; + private int index; + + protected RecordReader recordReader; + protected JobConf jobConf; + + public BucketizedSparkRecordReader( + InputFormat inputFormat, + BucketizedHiveInputSplit bucketizedSplit, + JobConf jobConf, + Reporter reporter) throws IOException { + this.recordReader = null; + this.jobConf = jobConf; + this.split = bucketizedSplit; + this.inputFormat = inputFormat; + this.reporter = reporter; + initNextRecordReader(); + } + + /** + * Get the record reader for the next chunk. + */ + private boolean initNextRecordReader() throws IOException { + if (recordReader != null) { + recordReader.close(); + recordReader = null; + if (index > 0) { + progress += split.getLength(index - 1); // done processing so far + } + } + + // if all chunks have been processed, nothing more to do. 
+ if (index == split.getNumSplits()) { + return false; + } + + try { + // get a record reader for the index-th chunk + recordReader = inputFormat.getRecordReader(split.getSplit(index), jobConf, reporter); + } catch (Exception e) { + recordReader = HiveIOExceptionHandlerUtil.handleRecordReaderCreationException(e, jobConf); + } + + index++; + return true; + } + + private boolean doNextWithExceptionHandler(K key, V value) throws IOException { + try { + return recordReader.next(key, value); + } catch (Exception e) { + return HiveIOExceptionHandlerUtil.handleRecordReaderNextException(e, jobConf); + } + } + + @Override + public boolean next(K key, V value) throws IOException { + try { + while ((recordReader == null) || !doNextWithExceptionHandler(key, value)) { + if (!initNextRecordReader()) { + return false; + } + } + return true; + } catch (Exception e) { + throw new IOException(e); + } + } + + @Override + public K createKey() { + return (K) recordReader.createKey(); + } + + @Override + public V createValue() { + return (V) recordReader.createValue(); + } + + @Override + public long getPos() throws IOException { + if (recordReader != null) { + return recordReader.getPos(); + } else { + return 0; + } + } + + @Override + public void close() throws IOException { + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + index = 0; + } + + @Override + public float getProgress() throws IOException { + return Math.min(1.0f, (recordReader == null ? + progress : recordReader.getPos()) / (float) (split.getLength())); + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index cc8907a0bbc93..f93838290c9fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.io.BucketizedSparkInputFormat import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer @@ -36,10 +37,11 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.rdd._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -50,9 +52,11 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} * A trait for subclasses that handle table scans. 
*/ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] + def makeRDDForTable(hiveTable: HiveTable, bucketSpec: Option[BucketSpec]): RDD[InternalRow] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] + def makeRDDForPartitionedTable( + partitions: Seq[HivePartition], + bucketSpec: Option[BucketSpec]): RDD[InternalRow] } @@ -90,11 +94,14 @@ class HadoopTableReader( override def conf: SQLConf = sparkSession.sessionState.conf - override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = + override def makeRDDForTable( + hiveTable: HiveTable, + bucketSpec: Option[BucketSpec]): RDD[InternalRow] = makeRDDForTable( hiveTable, Utils.classForName(tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], - filterOpt = None) + filterOpt = None, + bucketSpec) /** * Creates a Hadoop RDD to read data from the target table's data directory. Returns a transformed @@ -104,11 +111,14 @@ class HadoopTableReader( * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop. * @param filterOpt If defined, then the filter is used to reject files contained in the data * directory being read. If None, then all files are accepted. + * @param bucketSpec If the table is bucketed and the read operation should reflect that, then + * this is used to ensure that the RDD respects bucketing */ def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[InternalRow] = { + filterOpt: Option[PathFilter], + bucketSpec: Option[BucketSpec]): RDD[InternalRow] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") @@ -121,10 +131,14 @@ class HadoopTableReader( val tablePath = hiveTable.getPath val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt) - // logDebug("Table input: %s".format(tablePath)) - val ifc = hiveTable.getInputFormatClass - .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] - val hadoopRDD = createHadoopRdd(localTableDesc, inputPathStr, ifc) + val (minSplits, ifc) = bucketSpec match { + case Some(spec) => (spec.numBuckets, classOf[BucketizedSparkInputFormat[_, _]] + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]) + case None => (_minSplitsPerRDD, hiveTable.getInputFormatClass + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]) + } + + val hadoopRDD = createHadoopRdd(localTableDesc, inputPathStr, ifc, minSplits) val attrsWithIndex = attributes.zipWithIndex val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) @@ -139,10 +153,12 @@ class HadoopTableReader( deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] = { + override def makeRDDForPartitionedTable( + partitions: Seq[HivePartition], + bucketSpec: Option[BucketSpec]): RDD[InternalRow] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap - makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) + makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None, bucketSpec) } /** @@ -154,10 +170,12 @@ class HadoopTableReader( * class to use to deserialize input Writables from the corresponding partition. 
* @param filterOpt If defined, then the filter is used to reject files contained in the data * subdirectory of each partition being read. If None, then all files are accepted. + * @param bucketSpec If defined, this is the bucketing specification for the table */ def makeRDDForPartitionedTable( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[InternalRow] = { + filterOpt: Option[PathFilter], + bucketSpec: Option[BucketSpec]): RDD[InternalRow] = { // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists def verifyPartitionPath( @@ -201,8 +219,14 @@ class HadoopTableReader( val partDesc = Utilities.getPartitionDesc(partition) val partPath = partition.getDataLocation val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) - val ifc = partDesc.getInputFileFormatClass - .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + + val (minSplits, ifc) = bucketSpec match { + case Some(spec) => (spec.numBuckets, classOf[BucketizedSparkInputFormat[_, _]] + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]) + case None => (_minSplitsPerRDD, partDesc.getInputFileFormatClass + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]) + } + // Get partition field info val partSpec = partDesc.getPartSpec val partProps = partDesc.getProperties @@ -242,7 +266,7 @@ class HadoopTableReader( // Create local references so that the outer object isn't serialized. val localTableDesc = tableDesc - createHadoopRdd(localTableDesc, inputPathStr, ifc).mapPartitions { iter => + createHadoopRdd(localTableDesc, inputPathStr, ifc, minSplits).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() // SPARK-13709: For SerDes like AvroSerDe, some essential information (e.g. Avro schema @@ -266,10 +290,20 @@ class HadoopTableReader( }.toSeq // Even if we don't use any partitions, we still need an empty RDD - if (hivePartitionRDDs.size == 0) { + if (hivePartitionRDDs.isEmpty) { new EmptyRDD[InternalRow](sparkSession.sparkContext) } else { - new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) + val union = new UnionRDD(hivePartitionRDDs.head.context, hivePartitionRDDs) + + bucketSpec match { + case Some(spec) if partitionToDeserializer.size > 1 => + // If multiple Hive partitions of this table are read, the union RDD would have + // (num_partitions * num_buckets) partitions. Coalesce the union RDD so that the RDD + // partitions for the i-th bucket of every Hive partition fall into a single coalesced + // RDD partition. + new CoalescedRDD(union, spec.numBuckets, Some(new RoundRobinPartitionCoalescer)) + case _ => union + } } } @@ -294,7 +328,8 @@ class HadoopTableReader( private def createHadoopRdd( tableDesc: TableDesc, path: String, - inputFormatClass: Class[InputFormat[Writable, Writable]]): RDD[Writable] = { + inputFormatClass: Class[InputFormat[Writable, Writable]], + minSplits: Int): RDD[Writable] = { val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ @@ -305,7 +340,7 @@ class HadoopTableReader( inputFormatClass, classOf[Writable], classOf[Writable], - _minSplitsPerRDD) + minSplits) // Only take the value (skip the key) because Hive works only with values. 
rdd.map(_._2) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 4b923f5235a90..439726d696a4f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -50,7 +50,8 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.hive.HiveExternalCatalog.{DATASOURCE_SCHEMA, DATASOURCE_SCHEMA_NUMPARTS, DATASOURCE_SCHEMA_PART_PREFIX} +import org.apache.spark.sql.hive.HiveExternalCatalog.{DATASOURCE_PROVIDER, DATASOURCE_SCHEMA, + DATASOURCE_SCHEMA_NUMPARTS, DATASOURCE_SCHEMA_PART_PREFIX} import org.apache.spark.sql.hive.client.HiveClientImpl._ import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, Utils} @@ -936,7 +937,10 @@ private[hive] object HiveClientImpl { } table.bucketSpec match { - case Some(bucketSpec) if DDLUtils.isHiveTable(table) => + case Some(bucketSpec) if DDLUtils.isHiveTable(table) || + (table.tableType != CatalogTableType.VIEW && + table.properties.get(DATASOURCE_PROVIDER).isEmpty) => + hiveTable.setNumBuckets(bucketSpec.numBuckets) hiveTable.setBucketCols(bucketSpec.bucketColumnNames.toList.asJava) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 7dcaf170f9693..51e8b814fb43a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, +UnknownPartitioning} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ @@ -178,16 +180,25 @@ case class HiveTableScanExec( prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) } + @transient private lazy val prunedPartitions = prunePartitions(rawPartitions) + protected override def doExecute(): RDD[InternalRow] = { + val bucketSpec = if (sparkSession.sessionState.conf.bucketingEnabled) { + relation.tableMeta.bucketSpec + } else { + None + } + // Using dummyCallSite, as getCallSite can turn out to be expensive with // with multiple partitions. val rdd = if (!relation.isPartitioned) { Utils.withDummyCallSite(sqlContext.sparkContext) { - hadoopReader.makeRDDForTable(hiveQlTable) + hadoopReader.makeRDDForTable(hiveQlTable, bucketSpec) } } else { Utils.withDummyCallSite(sqlContext.sparkContext) { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(rawPartitions)) + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(rawPartitions), bucketSpec) } } val numOutputRows = longMetric("numOutputRows") @@ -203,6 +214,67 @@ case class HiveTableScanExec( } } + /** + * How is `outputPartitioning` determined ? 
+ * ----------------------------------------- + * `HashPartitioning` would be used when ALL these criteria match: + * + * - Table is bucketed + * - Bucketing is enabled + * - ALL the bucketing columns are being read from the table + * - In case of partitioned tables, if multiple partitions of the table are read, then they all + * should have the same properties (e.g. serde, input format class). + * + * How is `outputOrdering` determined ? + * ----------------------------------------- + * Sort ordering would be used when ALL these criteria match: + * + * 1. `HashPartitioning` is being used + * 2. A prefix (or all) of the sort columns are being read from the table. + * 3. Table is non-partitioned OR only a single partition of the table is read. + * In case of partitioned tables, if multiple partitions of the table are read, then the sort + * ordering is not used because the effective RDD partition for the bucket would comprise + * multiple files. Even though the files are individually sorted, the RDD partition as a whole + * is NOT sorted. + * + * Sort ordering would be over the prefix subset of `sort columns` being read from the table. + * e.g. + * Assume (col0, col2, col3) are the columns read from the table + * If sort columns are (col0, col1), then sort ordering would be considered as (col0) + * If sort columns are (col1, col0), then sort ordering would be empty as per rule #2 above + */ + override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { + relation.tableMeta.bucketSpec match { + case Some(spec) if sparkSession.sessionState.conf.bucketingEnabled => + def toAttribute(colName: String) = relation.dataCols.find(_.name == colName) + + val bucketColumns = spec.bucketColumnNames.flatMap(toAttribute) + val isMultiplePartitionScan = relation.isPartitioned && prunedPartitions.size > 1 + val allPropertiesSame = !isMultiplePartitionScan || + (prunedPartitions.map(_.getDeserializer.getClass).distinct.size == 1 && + prunedPartitions.map(_.getInputFormatClass).distinct.size == 1 && + prunedPartitions.map(_.getBucketCount).distinct.size == 1 && + prunedPartitions.map(_.getBucketCols).distinct.size == 1 && + prunedPartitions.map(_.getSortCols).distinct.size == 1) + + if (bucketColumns.size == spec.bucketColumnNames.size && allPropertiesSame) { + val partitioning = HashPartitioning(bucketColumns, spec.numBuckets, classOf[HiveHash]) + val sortColumns = spec.sortColumnNames.map(toAttribute).takeWhile(_.isDefined).map(_.get) + + val sortOrder = if (sortColumns.nonEmpty && !isMultiplePartitionScan) { + sortColumns.map(SortOrder(_, Ascending)) + } else { + Nil + } + (partitioning, sortOrder) + } else { + (UnknownPartitioning(0), Nil) + } + case _ => + (UnknownPartitioning(0), Nil) + } + } + override def doCanonicalize(): HiveTableScanExec = { val input: AttributeSeq = relation.output HiveTableScanExec( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3ce5b8469d6fc..7711e19844ffc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,16 +17,26 @@ package org.apache.spark.sql.hive.execution +import java.io.{File, IOException} +import java.net.URI +import java.text.SimpleDateFormat +import java.util.{Date, Locale, Random} + +import scala.collection.mutable +import scala.util.control.NonFatal 
+ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.{FileUtils, HiveStatsUtils} import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalog} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, HiveHash, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -164,25 +174,10 @@ case class InsertIntoHiveTable( } } - table.bucketSpec match { - case Some(bucketSpec) => - // Writes to bucketed hive tables are allowed only if user does not care about maintaining - // table's bucketing ie. both "hive.enforce.bucketing" and "hive.enforce.sorting" are - // set to false - val enforceBucketingConfig = "hive.enforce.bucketing" - val enforceSortingConfig = "hive.enforce.sorting" - - val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + - "currently does NOT populate bucketed output which is compatible with Hive." - - if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || - hadoopConf.get(enforceSortingConfig, "true").toBoolean) { - throw new AnalysisException(message) - } else { - logWarning(message + s" Inserting data anyways since both $enforceBucketingConfig and " + - s"$enforceSortingConfig are set to false.") - } - case _ => // do nothing since table has no bucketing + if (!overwrite && table.bucketSpec.isDefined) { + throw new SparkException("Appending data to hive bucketed table is not allowed as it " + + "will break the table's bucketing guarantee. Consider overwriting instead. 
Table = " + + table.qualifiedName) } val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name => @@ -201,6 +196,17 @@ case class InsertIntoHiveTable( allColumns = outputColumns, partitionAttributes = partitionAttributes) + // validate bucketing based on number of files before loading to metastore + table.bucketSpec.foreach { spec => + if (partition.nonEmpty && numDynamicPartitions > 0) { + val validPartitionPaths = + getValidPartitionPaths(hadoopConf, tmpLocation, numDynamicPartitions) + validateBuckets(hadoopConf, validPartitionPaths, table.bucketSpec.get.numBuckets) + } else { + validateBuckets(hadoopConf, Seq(tmpLocation), table.bucketSpec.get.numBuckets) + } + } + if (partition.nonEmpty) { if (numDynamicPartitions > 0) { externalCatalog.loadDynamicPartitions( @@ -216,10 +222,10 @@ case class InsertIntoHiveTable( // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries // scalastyle:on val oldPart = - externalCatalog.getPartitionOption( - table.database, - table.identifier.table, - partitionSpec) + externalCatalog.getPartitionOption( + table.database, + table.identifier.table, + partitionSpec) var doHiveOverwrite = overwrite @@ -265,4 +271,130 @@ case class InsertIntoHiveTable( isSrcLocal = false) } } + + private def getValidPartitionPaths( + conf: Configuration, + outputPath: Path, + numDynamicPartitions: Int): Seq[Path] = { + val validPartitionPaths = mutable.HashSet[Path]() + try { + val fs = outputPath.getFileSystem(conf) + HiveStatsUtils.getFileStatusRecurse(outputPath, numDynamicPartitions, fs) + .filter(_.isDirectory) + .foreach(d => validPartitionPaths.add(d.getPath)) + } catch { + case e: IOException => + throw new SparkException("Unable to extract partition paths from temporary output " + + s"location $outputPath due to : ${e.getMessage}", e) + } + validPartitionPaths.toSeq + } + + private def validateBuckets(conf: Configuration, outputPaths: Seq[Path], numBuckets: Int) = { + val bucketedFilePattern = """part-(\d+)(?:.*)?$""".r + + def getBucketIdFromFilename(fileName : String): Option[Int] = + fileName match { + case bucketedFilePattern(bucketId) => Some(bucketId.toInt) + case _ => None + } + + outputPaths.foreach(outputPath => { + val fs = outputPath.getFileSystem(conf) + val allFiles = fs.listStatus(outputPath) + if (allFiles != null && allFiles.nonEmpty) { + val files = allFiles.filterNot(_.getPath.getName == "_SUCCESS") + .map(_.getPath.getName) + .sortBy(_.toString) + + var expectedBucketId = 0 + files.foreach { case file => + getBucketIdFromFilename(file) match { + case Some(id) if id == expectedBucketId => + expectedBucketId += 1 + case Some(_) => + throw new SparkException( + s"Potentially missing bucketed output files in temporary bucketed output " + + s"location. Aborting job. Output location : $outputPath, files found : " + + files.mkString("[", ",", "]")) + case None => + throw new SparkException( + s"Invalid file found in temporary bucketed output location. Aborting job. " + + s"Output location : $outputPath, bad file : $file") + } + } + + if (expectedBucketId != numBuckets) { + throw new SparkException( + s"Potentially missing bucketed output files in temporary bucketed output location. " + + s"Aborting job. 
Output location : $outputPath, files found : " + + files.mkString("[", ",", "]")) + } + } + }) + } + + private def getPartitionAndDataColumns: (Seq[Attribute], Seq[Attribute]) = { + val allColumns = query.output + val partitionColumnNames = partition.keySet + allColumns.partition(c => partitionColumnNames.contains(c.name)) + } + + /** + * Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + * guarantee the data distribution is same between shuffle and bucketed data source, which + * enables us to only shuffle one side when join a bucketed table and a normal one. + */ + private def getBucketIdExpression(dataColumns: Seq[Attribute]): Option[Expression] = + table.bucketSpec.map { spec => + HashPartitioning( + spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get), + spec.numBuckets, + classOf[HiveHash] + ).partitionIdExpression + } + + /** + * If the table is bucketed, then requiredDistribution would be the bucket columns. + * Else it would be empty + */ + override def requiredDistribution: Seq[Distribution] = table.bucketSpec match { + case Some(bucketSpec) => + val (_, dataColumns) = getPartitionAndDataColumns + Seq(HashClusteredDistribution( + bucketSpec.bucketColumnNames.map(b => dataColumns.find(_.name == b).get), + Option(bucketSpec.numBuckets), + classOf[HiveHash])) + + case _ => Seq(UnspecifiedDistribution) + } + + /** + * How is `requiredOrdering` determined ? + * + * table type | normal table | bucketed table + * --------------------+--------------------+----------------------------------------------- + * non-partitioned | Nil | sort columns + * static partition | Nil | sort columns + * dynamic partition | partition columns | (partition columns + bucketId + sort columns) + * --------------------+--------------------+----------------------------------------------- + */ + override def requiredOrdering: Seq[Seq[SortOrder]] = { + val (partitionColumns, dataColumns) = getPartitionAndDataColumns + val isDynamicPartitioned = + table.partitionColumnNames.nonEmpty && partition.values.exists(_.isEmpty) + + val sortExpressions = table.bucketSpec match { + case Some(bucketSpec) => + val sortColumns = bucketSpec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + if (isDynamicPartitioned) { + partitionColumns ++ getBucketIdExpression(dataColumns) ++ sortColumns + } else { + sortColumns + } + case _ => if (isDynamicPartitioned) partitionColumns else Nil + } + + Seq(sortExpressions.map(SortOrder(_, Ascending))) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 9a6607f2f2c6c..6b16b410e3592 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.ql.exec.TaskRunner import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter @@ -52,6 +52,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { 
fileSinkConf: FileSinkDesc, outputLocation: String, allColumns: Seq[Attribute], + bucketIdExpression: Option[Expression] = None, customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { @@ -83,7 +84,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, allColumns), hadoopConf = hadoopConf, partitionColumns = partitionAttributes, - bucketSpec = None, + bucketIdExpression = bucketIdExpression, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), options = Map.empty) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index ab91727049ff5..efcd3f186f1e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.hive import java.io.File +import java.net.URI import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.catalyst.expressions.HiveHashFunction import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -511,44 +512,6 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } } - testBucketedTable("INSERT should NOT fail if strict bucketing is NOT enforced") { - tableName => - withSQLConf("hive.enforce.bucketing" -> "false", "hive.enforce.sorting" -> "false") { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 4, 2 AS c, 3 AS b") - checkAnswer(sql(s"SELECT a, b, c, d FROM $tableName"), Row(1, 2, 3, 4)) - } - } - - testBucketedTable("INSERT should fail if strict bucketing / sorting is enforced") { - tableName => - withSQLConf("hive.enforce.bucketing" -> "true", "hive.enforce.sorting" -> "false") { - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") - } - } - withSQLConf("hive.enforce.bucketing" -> "false", "hive.enforce.sorting" -> "true") { - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") - } - } - withSQLConf("hive.enforce.bucketing" -> "true", "hive.enforce.sorting" -> "true") { - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName SELECT 1, 2, 3, 4") - } - } - } - - test("SPARK-20594: hive.exec.stagingdir was deleted by Hive") { - // Set hive.exec.stagingdir under the table directory without start with ".". 
- withSQLConf("hive.exec.stagingdir" -> "./test") { - withTable("test_table") { - sql("CREATE TABLE test_table (key int)") - sql("INSERT OVERWRITE TABLE test_table SELECT 1") - checkAnswer(sql("SELECT * FROM test_table"), Row(1)) - } - } - } - test("insert overwrite to dir from hive metastore table") { withTempDir { dir => val path = dir.toURI.getPath @@ -750,4 +713,197 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } } } + + private def validateBucketingAndSorting(numBuckets: Int, dir: URI): Unit = { + val bucketFiles = new File(dir).listFiles().filter(_.getName.startsWith("part-")) + .sortWith((x, y) => x.getName < y.getName) + assert(bucketFiles.length === numBuckets) + + bucketFiles.zipWithIndex.foreach { case(bucketFile, bucketId) => + val rows = spark.read.format("text").load(bucketFile.getAbsolutePath).collect() + var prevKey: Option[Int] = None + rows.foreach(row => { + val key = row.getString(0).split("\t")(0).toInt + assert(HiveHashFunction.hash(key, IntegerType, seed = 0) % numBuckets === bucketId) + + if (prevKey.isDefined) { + assert(prevKey.get <= key) + } + prevKey = Some(key) + }) + } + } + + test("Write data to a non-partitioned bucketed table") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + val session = spark.sessionState + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString)).toDF("key", "value") + .write.mode(SaveMode.Overwrite).insertInto(tableName) + + val dir = session.catalog.defaultTablePath(session.sqlParser.parseTableIdentifier(tableName)) + validateBucketingAndSorting(numBuckets, dir) + } + } + + test("Write data to a bucketed table with static partition") { + val numBuckets = 8 + val tableName = "bucketizedTable" + val sourceTableName = "sourceTable" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName, sourceTableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString)) + .toDF("key", "value") + .createOrReplaceTempView(sourceTableName) + + sql(s""" + |INSERT OVERWRITE TABLE $tableName PARTITION(part1="val1", part2="val2") + |SELECT key, value + |FROM $sourceTableName + |""".stripMargin) + + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> "val1", "part2" -> "val2") + ).location + + validateBucketingAndSorting(numBuckets, dir) + } + } + } + + test("Write data to a bucketed table with dynamic partitions") { + val numBuckets = 7 + val tableName = "bucketizedTable" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 1000) + .map(i => (i, i.toString, (if (i > 50) i % 2 else 2 - i % 2).toString, (i % 3).toString)) + .toDF("key", "value", "part1", "part2") + .write.mode(SaveMode.Overwrite).insertInto(tableName) + + (0 until 2).zip(0 until 
3).foreach { case (part1, part2) => + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> part1.toString, "part2" -> part2.toString) + ).location + + validateBucketingAndSorting(numBuckets, dir) + } + } + } + } + + test("Write data to a bucketed table with dynamic partitions (along with static partitions)") { + val numBuckets = 8 + val tableName = "bucketizedTable" + val sourceTableName = "sourceTable" + val part1StaticValue = "0" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName, sourceTableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString, (i % 3).toString)) + .toDF("key", "value", "part") + .createOrReplaceTempView(sourceTableName) + + sql(s""" + |INSERT OVERWRITE TABLE $tableName PARTITION(part1="$part1StaticValue", part2) + |SELECT key, value, part + |FROM $sourceTableName + |""".stripMargin) + + (0 until 3).foreach { case part2 => + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> part1StaticValue, "part2" -> part2.toString) + ).location + + validateBucketingAndSorting(numBuckets, dir) + } + } + } + } + + test("Appends to bucketed table should NOT be allowed as it breaks bucketing guarantee") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + val df = (0 until 100).map(i => (i, i.toString)).toDF("key", "value") + val e = intercept[SparkException] { + df.write.mode(SaveMode.Append).insertInto(tableName) + } + assert(e.getMessage.contains("Appending data to hive bucketed table is not allowed")) + } + } + + test("Fail the query if number of files produced != number of buckets") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + val df = (0 until (numBuckets / 2)).map(i => (i, i.toString)).toDF("key", "value") + val e = intercept[SparkException] { + df.write.mode(SaveMode.Overwrite).insertInto(tableName) + } + assert(e.getMessage.contains("Potentially missing bucketed output files")) + } + } + + test("SPARK-20594: hive.exec.stagingdir was deleted by Hive") { + // Set hive.exec.stagingdir under the table directory without start with ".". 
+ withSQLConf("hive.exec.stagingdir" -> "./test") { + withTable("test_table") { + sql("CREATE TABLE test_table (key int)") + sql("INSERT OVERWRITE TABLE test_table SELECT 1") + checkAnswer(sql("SELECT * FROM test_table"), Row(1)) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveBucketingSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveBucketingSuite.scala new file mode 100644 index 0000000000000..0584fc6ab2b09 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveBucketingSuite.scala @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.File +import java.net.URI + +import org.apache.spark.SparkException +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, HiveHash, HiveHashFunction} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.IntegerType + +/** + * Tests Spark's support for Hive bucketing + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL+BucketedTables + */ +class HiveBucketingSuite extends HiveComparisonTest with SQLTestUtils with TestHiveSingleton { + private val (table1, table2, table3) = ("table1", "table2", "table3") + private val (partitionedTable1, partitionedTable2) = ("partitionedTable1", "partitionedTable2") + + private val bucketSpec = Some(BucketSpec(8, Seq("key"), Seq("key", "value"))) + + override def beforeAll(): Unit = { + super.beforeAll() + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + createTable(table1, isPartitioned = false, bucketSpec) + createTable(table2, isPartitioned = false, bucketSpec) + createTable(table3, isPartitioned = false, None) + createTable(partitionedTable1, isPartitioned = true, bucketSpec) + createTable(partitionedTable2, isPartitioned = true, bucketSpec) + } + } + + override def afterAll(): Unit = { + try { + Seq(table1, table2, partitionedTable1, partitionedTable2) + .foreach(table => sql(s"DROP TABLE IF EXISTS $table")) + } finally { + super.afterAll() + } + } + + private def createTable( + tableName: String, + isPartitioned: Boolean, + bucketSpec: Option[BucketSpec]): Unit = { + + val bucketClause = + 
bucketSpec.map(b => + s"CLUSTERED BY (${b.bucketColumnNames.mkString(",")}) " + + s"SORTED BY (${b.sortColumnNames.map(_ + " ASC").mkString(" ,")}) " + + s"INTO ${b.numBuckets} buckets" + ).getOrElse("") + + val partitionClause = if (isPartitioned) "PARTITIONED BY(part STRING)" else "" + + sql(s""" + |CREATE TABLE IF NOT EXISTS $tableName (key int, value string) $partitionClause + |$bucketClause ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + val data = 0 until 20 + val df = if (isPartitioned) { + (data.map(i => (i, i.toString, "part0")) ++ data.map(i => (i, i.toString, "part1"))) + .toDF("key", "value", "part") + } else { + data.map(i => (i, i.toString)).toDF("key", "value") + } + df.write.mode(SaveMode.Overwrite).insertInto(tableName) + } + + case class BucketedTableTestSpec( + bucketSpec: Option[BucketSpec] = bucketSpec, + expectShuffle: Boolean = false, + expectSort: Boolean = false) + + private def testBucketing( + queryString: String, + bucketedTableTestSpecLeft: BucketedTableTestSpec = BucketedTableTestSpec(), + bucketedTableTestSpecRight: BucketedTableTestSpec = BucketedTableTestSpec(), + bucketingEnabled: Boolean = true): Unit = { + + def validateChildNode( + child: SparkPlan, + bucketSpec: Option[BucketSpec], + expectedShuffle: Boolean, + expectedSort: Boolean): Unit = { + val exchange = child.find(_.isInstanceOf[ShuffleExchangeExec]) + assert(if (expectedShuffle) exchange.isDefined else exchange.isEmpty) + + val sort = child.find(_.isInstanceOf[SortExec]) + assert(if (expectedSort) sort.isDefined else sort.isEmpty) + + val tableScanNode = child.find(_.isInstanceOf[HiveTableScanExec]) + assert(tableScanNode.isDefined) + val tableScan = tableScanNode.get.asInstanceOf[HiveTableScanExec] + + bucketSpec match { + case Some(spec) if bucketingEnabled => + assert(tableScan.outputPartitioning.isInstanceOf[HashPartitioning]) + val part = tableScan.outputPartitioning.asInstanceOf[HashPartitioning] + assert(part.hashingFunctionClass == classOf[HiveHash]) + assert(part.numPartitions == spec.numBuckets) + + for ((columnName, index) <- spec.bucketColumnNames.zipWithIndex) { + part.expressions(index) match { + case a: AttributeReference => assert(a.name == columnName) + } + } + + import org.apache.spark.sql.catalyst.dsl.expressions._ + part.semanticEquals(HashPartitioning(Seq('key), spec.numBuckets, classOf[HiveHash])) + + val ordering = tableScan.outputOrdering + for ((columnName, sortOrder) <- spec.sortColumnNames.zip(ordering)) { + assert(sortOrder.direction === Ascending) + sortOrder.child match { + case a: AttributeReference => assert(a.name == columnName) + } + } + case _ => // do nothing + } + } + + val BucketedTableTestSpec(bucketSpecLeft, shuffleLeft, sortLeft) = bucketedTableTestSpecLeft + val BucketedTableTestSpec(bucketSpecRight, shuffleRight, sortRight) = bucketedTableTestSpecRight + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + SQLConf.BUCKETING_ENABLED.key -> bucketingEnabled.toString) { + + val plan = sql(queryString).queryExecution.executedPlan + val joinNode = plan.find(_.isInstanceOf[SortMergeJoinExec]) + assert(joinNode.isDefined) + val joinOperator = joinNode.get.asInstanceOf[SortMergeJoinExec] + + validateChildNode(joinOperator.left, bucketSpecLeft, shuffleLeft, sortLeft) + validateChildNode(joinOperator.right, bucketSpecRight, shuffleRight, sortRight) + } + } + + test("Join two bucketed tables with bucketing enabled") { + testBucketing(s"SELECT * FROM $table1 a JOIN $table2 b ON a.key = b.key") + } + + test("Join two bucketed 
tables with bucketing DISABLED") { + testBucketing( + s"SELECT * FROM $table1 a JOIN $table2 b ON a.key = b.key", + BucketedTableTestSpec(expectShuffle = true, expectSort = true), + BucketedTableTestSpec(expectShuffle = true, expectSort = true), + bucketingEnabled = false) + } + + test("Join a regular table with a bucketed table") { + testBucketing( + s"SELECT * FROM $table3 a JOIN $table1 b ON a.key = b.key", + BucketedTableTestSpec(bucketSpec = None, expectShuffle = true, expectSort = true)) + } + + test("Join two bucketed tables but not over their bucketing column(s)") { + testBucketing( + s"SELECT a.value, b.value FROM $table1 a JOIN $table2 b ON a.value = b.value", + BucketedTableTestSpec(expectShuffle = true, expectSort = true), + BucketedTableTestSpec(expectShuffle = true, expectSort = true)) + } + + test("Join where predicate has other clauses in addition to bucketing column(s)") { + testBucketing( + s"SELECT a.value, b.value " + + s"FROM $table1 a JOIN $table2 b " + + s"ON a.key = b.key AND a.value = b.value", + BucketedTableTestSpec(expectShuffle = true, expectSort = true), + BucketedTableTestSpec(expectShuffle = true, expectSort = true)) + } + + test("Join two partitioned tables (single partition scan) with bucketing enabled") { + testBucketing( + s""" + SELECT * FROM $partitionedTable1 a JOIN $partitionedTable2 b + ON a.key = b.key AND a.part = "part0" AND b.part = "part0" + """.stripMargin + ) + } + + test("Join two partitioned tables with single partition scanned on one side and multiple " + + "partition scanned on the other one") { + testBucketing( + s""" + SELECT * FROM $partitionedTable1 a JOIN $partitionedTable2 b + ON a.key = b.key AND b.part = "part0" + """.stripMargin, + BucketedTableTestSpec(expectSort = true) + ) + } + + test("Join partitioned tables with non-partitioned table (both bucketed)") { + testBucketing( + s""" + SELECT * FROM $table1 a JOIN $partitionedTable2 b ON a.key = b.key AND b.part = "part0" + """.stripMargin + ) + } + + test("Join two partitioned tables (multiple partitions scan) with bucketing enabled") { + testBucketing( + s"SELECT * FROM $partitionedTable1 a JOIN $partitionedTable2 b ON a.key = b.key", + BucketedTableTestSpec(expectSort = true), + BucketedTableTestSpec(expectSort = true) + ) + } + + private def validateBucketingAndSorting(numBuckets: Int, dir: URI): Unit = { + val bucketFiles = new File(dir).listFiles().filter(_.getName.startsWith("part-")) + .sortWith((x, y) => x.getName < y.getName) + assert(bucketFiles.length === numBuckets) + + bucketFiles.zipWithIndex.foreach { case(bucketFile, bucketId) => + val rows = spark.read.format("text").load(bucketFile.getAbsolutePath).collect() + var prevKey: Option[Int] = None + rows.foreach(row => { + val key = row.getString(0).split("\t")(0).toInt + assert(HiveHashFunction.hash(key, IntegerType, seed = 0) % numBuckets === bucketId) + + if (prevKey.isDefined) { + assert(prevKey.get <= key) + } + prevKey = Some(key) + }) + } + } + + test("Write data to a non-partitioned bucketed table") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + val session = spark.sessionState + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString)).toDF("key", "value") + .write.mode(SaveMode.Overwrite).insertInto(tableName) + + val dir = 
session.catalog.defaultTablePath(session.sqlParser.parseTableIdentifier(tableName)) + validateBucketingAndSorting(numBuckets, dir) + } + } + + test("Write data to a bucketed table with static partition") { + val numBuckets = 8 + val tableName = "bucketizedTable" + val sourceTableName = "sourceTable" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName, sourceTableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString)) + .toDF("key", "value") + .createOrReplaceTempView(sourceTableName) + + sql(s""" + |INSERT OVERWRITE TABLE $tableName PARTITION(part1="val1", part2="val2") + |SELECT key, value + |FROM $sourceTableName + |""".stripMargin) + + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> "val1", "part2" -> "val2") + ).location + + validateBucketingAndSorting(numBuckets, dir) + } + } + } + + test("Write data to a bucketed table with dynamic partitions") { + val numBuckets = 7 + val tableName = "bucketizedTable" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString, (if (i > 50) i % 2 else 2 - i % 2).toString, (i % 3).toString)) + .toDF("key", "value", "part1", "part2") + .write.mode(SaveMode.Overwrite).insertInto(tableName) + + (0 until 2).zip(0 until 3).foreach { case (part1, part2) => + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> part1.toString, "part2" -> part2.toString) + ).location + + validateBucketingAndSorting(numBuckets, dir) + } + } + } + } + + test("Write data to a bucketed table with dynamic partitions (along with static partitions)") { + val numBuckets = 8 + val tableName = "bucketizedTable" + val sourceTableName = "sourceTable" + val part1StaticValue = "0" + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(tableName, sourceTableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |PARTITIONED BY(part1 STRING, part2 STRING) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + (0 until 100) + .map(i => (i, i.toString, (i % 3).toString)) + .toDF("key", "value", "part") + .createOrReplaceTempView(sourceTableName) + + sql(s""" + |INSERT OVERWRITE TABLE $tableName PARTITION(part1="$part1StaticValue", part2) + |SELECT key, value, part + |FROM $sourceTableName + |""".stripMargin) + + (0 until 3).foreach { case part2 => + val dir = spark.sessionState.catalog.getPartition( + spark.sessionState.sqlParser.parseTableIdentifier(tableName), + Map("part1" -> part1StaticValue, "part2" -> part2.toString) + ).location + + validateBucketingAndSorting(numBuckets, dir) + } + } + } + } + + test("Appends to bucketed table should NOT be allowed as it breaks bucketing guarantee") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + sql(s""" + |CREATE TABLE 
$tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + val df = (0 until 100).map(i => (i, i.toString)).toDF("key", "value") + val e = intercept[SparkException] { + df.write.mode(SaveMode.Append).insertInto(tableName) + } + assert(e.getMessage.contains("Appending data to hive bucketed table is not allowed")) + } + } + + test("Fail the query if number of files produced != number of buckets") { + val numBuckets = 8 + val tableName = "nonPartitionedBucketed" + + withTable(tableName) { + sql(s""" + |CREATE TABLE $tableName (key int, value string) + |CLUSTERED BY (key) SORTED BY (key ASC) into $numBuckets buckets + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |""".stripMargin) + + val df = (0 until (numBuckets / 2)).map(i => (i, i.toString)).toDF("key", "value") + val e = intercept[SparkException] { + df.write.mode(SaveMode.Overwrite).insertInto(tableName) + } + assert(e.getMessage.contains("Potentially missing bucketed output files")) + } + } +}
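
A small, standalone illustration of the `outputOrdering` prefix rule documented on `HiveTableScanExec` above. The column names follow the example in that comment; the helper `sortPrefix` is hypothetical and only mirrors the `takeWhile(_.isDefined)` step, it is not code from this patch.

    val readColumns = Set("col0", "col2", "col3")   // columns read from the table

    // Only the leading run of sort columns that are all read keeps an ordering guarantee.
    def sortPrefix(sortColumns: Seq[String]): Seq[String] =
      sortColumns.takeWhile(readColumns.contains)

    assert(sortPrefix(Seq("col0", "col1")) == Seq("col0"))   // the prefix (col0) survives
    assert(sortPrefix(Seq("col1", "col0")) == Seq.empty)     // col1 is not read, so no ordering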
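
On the write side, `getBucketIdExpression` routes bucket ids through `HashPartitioning(..., classOf[HiveHash]).partitionIdExpression`, and the tests verify files with `HiveHashFunction.hash(key, IntegerType, seed = 0) % numBuckets`. The sketch below shows that arithmetic for a single int bucketing column; the key and bucket count are made up, and `Pmod` is used directly only to mirror what `partitionIdExpression` is expected to produce with a `HiveHash` hashing function.

    import org.apache.spark.sql.catalyst.expressions.{HiveHash, HiveHashFunction, Literal, Pmod}
    import org.apache.spark.sql.types.IntegerType

    val numBuckets = 8   // illustrative bucket count
    val key = 42         // illustrative value of the single bucketing column

    // Expression form: pmod(hiveHash(key), numBuckets) yields a non-negative bucket id.
    val fromExpression = Pmod(HiveHash(Seq(Literal(key))), Literal(numBuckets)).eval()

    // Interpreted form, as used by validateBucketingAndSorting in the test suites.
    val fromFunction = HiveHashFunction.hash(key, IntegerType, seed = 0) % numBuckets

    // For a single non-negative int key both forms reduce to key % numBuckets.
    assert(fromExpression == fromFunction)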
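
The `requiredOrdering` table in `InsertIntoHiveTable` can also be read as a decision function over the table type. The sketch below is only a restatement of that table; "bucketId" stands in for the bucket id expression and all names are hypothetical.

    // Dynamic-partition writes sort by partition columns, then bucket id, then sort columns;
    // other writes to bucketed tables only need the bucket sort columns.
    def requiredSortColumns(
        isBucketed: Boolean,
        isDynamicPartitioned: Boolean,
        partitionCols: Seq[String],
        sortCols: Seq[String]): Seq[String] = (isBucketed, isDynamicPartitioned) match {
      case (true, true)   => partitionCols ++ Seq("bucketId") ++ sortCols
      case (true, false)  => sortCols
      case (false, true)  => partitionCols
      case (false, false) => Nil
    }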
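
`validateBuckets` expects the temporary output of a bucketed insert to contain exactly `numBuckets` data files named `part-<bucketId>...`, with consecutive ids starting at 0, before anything is loaded into the metastore. A self-contained sketch of that invariant, using hypothetical file names:

    val numBuckets = 3   // declared bucket count of a hypothetical table
    val bucketedFilePattern = """part-(\d+)(?:.*)?$""".r

    def bucketId(fileName: String): Option[Int] = fileName match {
      case bucketedFilePattern(id) => Some(id.toInt)
      case _ => None
    }

    // Hypothetical listing of the temporary output directory (a _SUCCESS marker would be ignored).
    val files = Seq("part-00000-c000", "part-00001-c000", "part-00002-c000").sorted

    val ids = files.flatMap(bucketId)
    // Every file must carry a bucket id and the ids must be exactly 0 .. numBuckets - 1.
    assert(ids.size == files.size)
    assert(ids == (0 until numBuckets))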