diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index e56f8105fc9a7..69bbce086d3d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -21,8 +21,9 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.datasources.FileFormatWriter
@@ -60,5 +61,9 @@ trait DataWritingCommand extends Command {
     new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
   }
 
+  def requiredDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution)
+
+  def requiredOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
   def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 2cc0e38adc2ee..3030cc9fef922 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -23,9 +23,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.debug._
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -112,6 +113,10 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
 
   override def nodeName: String = "Execute " + cmd.nodeName
 
+  override def requiredChildDistribution: Seq[Distribution] = cmd.requiredDistribution
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = cmd.requiredOrdering
+
   override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray
 
   override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 1d80a69bc5a1d..cec901d8c9be7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -109,7 +109,7 @@ object FileFormatWriter extends Logging {
       outputSpec: OutputSpec,
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
-      bucketSpec: Option[BucketSpec],
+      bucketIdExpression: Option[Expression],
       statsTrackers: Seq[WriteJobStatsTracker],
       options: Map[String, String])
     : Set[String] = {
@@ -122,17 +122,6 @@ object FileFormatWriter extends Logging {
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains)
 
-    val bucketIdExpression = bucketSpec.map { spec =>
-      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-    val sortColumns = bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
-    }
-
     val caseInsensitiveOptions = CaseInsensitiveMap(options)
 
     // Note: prepareWrite has side effect. It sets "job".
@@ -156,19 +145,6 @@ object FileFormatWriter extends Logging {
       statsTrackers = statsTrackers
     )
 
-    // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
-    // the sort order doesn't matter
-    val actualOrdering = plan.outputOrdering.map(_.child)
-    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
-      false
-    } else {
-      requiredOrdering.zip(actualOrdering).forall {
-        case (requiredOrder, childOutputOrder) =>
-          requiredOrder.semanticEquals(childOutputOrder)
-      }
-    }
-
     SQLExecution.checkSQLExecutionId(sparkSession)
 
     // This call shouldn't be put into the `try` block below because it only initializes and
@@ -176,20 +152,7 @@ object FileFormatWriter extends Logging {
     committer.setupJob(job)
 
     try {
-      val rdd = if (orderingMatched) {
-        plan.execute()
-      } else {
-        // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
-        // the physical plan may have different attribute ids due to optimizer removing some
-        // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-        val orderingExpr = requiredOrdering
-          .map(SortOrder(_, Ascending))
-          .map(BindReferences.bindReference(_, outputSpec.outputColumns))
-        SortExec(
-          orderingExpr,
-          global = false,
-          child = plan).execute()
-      }
+      val rdd = plan.execute()
       val ret = new Array[WriteTaskResult](rdd.partitions.length)
       sparkSession.sparkContext.runJob(
         rdd,
@@ -202,7 +165,7 @@ object FileFormatWriter extends Logging {
             committer,
             iterator = iter)
         },
-        0 until rdd.partitions.length,
+        rdd.partitions.indices,
         (index, res: WriteTaskResult) => {
           committer.onTaskCommit(res.commitMsg)
           ret(index) = res
@@ -521,18 +484,18 @@ object FileFormatWriter extends Logging {
       var recordsInFile: Long = 0L
       var fileCounter = 0
       val updatedPartitions = mutable.Set[String]()
-      var currentPartionValues: Option[UnsafeRow] = None
+      var currentPartitionValues: Option[UnsafeRow] = None
       var currentBucketId: Option[Int] = None
 
       for (row <- iter) {
         val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
         val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None
 
-        if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
+        if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) {
           // See a new partition or bucket - write to a new partition dir (or a new bucket file).
-          if (isPartitioned && currentPartionValues != nextPartitionValues) {
-            currentPartionValues = Some(nextPartitionValues.get.copy())
-            statsTrackers.foreach(_.newPartition(currentPartionValues.get))
+          if (isPartitioned && currentPartitionValues != nextPartitionValues) {
+            currentPartitionValues = Some(nextPartitionValues.get.copy())
+            statsTrackers.foreach(_.newPartition(currentPartitionValues.get))
           }
           if (isBucketed) {
             currentBucketId = nextBucketId
@@ -543,7 +506,7 @@ object FileFormatWriter extends Logging {
           fileCounter = 0
 
           releaseResources()
-          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
+          newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
         } else if (desc.maxRecordsPerFile > 0 &&
             recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
@@ -554,7 +517,7 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
 
           releaseResources()
-          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
+          newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
         }
         val outputRow = getOutputRow(row)
         currentWriter.write(outputRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index dd7ef0d15c140..45b0d36c6c6f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -25,8 +25,10 @@ import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression,
+  SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -150,6 +152,10 @@ case class InsertIntoHadoopFsRelationCommand(
         }
       }
 
+      val partitionSet = AttributeSet(partitionColumns)
+      val dataColumns = query.output.filterNot(partitionSet.contains)
+      val bucketIdExpression = getBucketIdExpression(dataColumns)
+
       val updatedPartitionPaths =
         FileFormatWriter.write(
           sparkSession = sparkSession,
@@ -160,7 +166,7 @@ case class InsertIntoHadoopFsRelationCommand(
             qualifiedOutputPath.toString, customPartitionLocations, outputColumns),
           hadoopConf = hadoopConf,
           partitionColumns = partitionColumns,
-          bucketSpec = bucketSpec,
+          bucketIdExpression = bucketIdExpression,
           statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
           options = options)
 
@@ -184,6 +190,43 @@ case class InsertIntoHadoopFsRelationCommand(
     Seq.empty[Row]
   }
 
+  private def getBucketIdExpression(dataColumns: Seq[Attribute]): Option[Expression] = {
+    bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is the same between shuffle and bucketed data source,
+      // which enables us to only shuffle one side when joining a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+  }
+
+  /**
+   * How is `requiredOrdering` determined?
+ * + * table type | requiredOrdering + * -----------------+------------------------------------------------- + * normal table | partition columns + * bucketed table | (partition columns + bucketId + sort columns) + * -----------------+------------------------------------------------- + */ + override def requiredOrdering: Seq[Seq[SortOrder]] = { + val sortExpressions = bucketSpec match { + case Some(spec) => + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = query.output.filterNot(partitionSet.contains) + val bucketIdExpression = getBucketIdExpression(dataColumns) + val sortColumns = spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + partitionColumns ++ bucketIdExpression ++ sortColumns + + case _ => partitionColumns + } + if (sortExpressions.nonEmpty) { + Seq(sortExpressions.map(SortOrder(_, Ascending))) + } else { + Seq.fill(children.size)(Nil) + } + } + /** * Deletes all partition files that match the specified static prefix. Partitions with custom * locations are also cleared based on the custom locations map given to this class. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 2715fa93d0e98..bdc99a4085f0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -128,7 +128,7 @@ class FileStreamSink( outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output), hadoopConf = hadoopConf, partitionColumns = partitionColumns, - bucketSpec = None, + bucketIdExpression = None, statsTrackers = Nil, options = options) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3ce5b8469d6fc..755283767aa42 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalog} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils @@ -71,6 +71,29 @@ case class InsertIntoHiveTable( ifPartitionNotExists: Boolean, outputColumns: Seq[Attribute]) extends SaveAsHiveFile { + + /** + * For partitioned tables, `requiredOrdering` is over static partition columns of table + */ + override def requiredOrdering: Seq[Seq[SortOrder]] = { + if (table.partitionColumnNames.nonEmpty) { + val numDynamicPartitions = partition.values.count(_.isEmpty) + val partitionAttributes = table.partitionColumnNames.takeRight(numDynamicPartitions).map { + name => + query.resolve( + name :: Nil, + SparkSession.getActiveSession.get.sessionState.analyzer.resolver + ).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]") + }.asInstanceOf[Attribute] + } + 
+      Seq(partitionAttributes.map(SortOrder(_, Ascending)))
+    } else {
+      Seq.fill(children.size)(Nil)
+    }
+  }
+
   /**
    * Inserts all the rows in the table into Hive. Row objects are properly serialized with the
    * `org.apache.hadoop.hive.serde2.SerDe` and the
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
index 9a6607f2f2c6c..f3877abac8110 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
@@ -83,7 +83,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand {
         FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, allColumns),
       hadoopConf = hadoopConf,
       partitionColumns = partitionAttributes,
-      bucketSpec = None,
+      bucketIdExpression = None,
       statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
      options = Map.empty)
   }
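
A minimal usage sketch of the write path this patch affects (not part of the diff itself). The object name, table name, and data below are illustrative only; the DataFrameWriter calls (partitionBy, bucketBy, sortBy, saveAsTable) are standard Spark APIs. For such a partitioned, bucketed, sorted write, the required ordering is (partition columns ++ bucket id ++ sort columns); with this change it is declared via DataWritingCommand.requiredOrdering, surfaced through DataWritingCommandExec.requiredChildOrdering, and satisfied by the planner, instead of FileFormatWriter wrapping the plan in its own SortExec:

// Hypothetical sketch, assuming a local SparkSession; names and data are made up.
import org.apache.spark.sql.SparkSession

object RequiredOrderingDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("required-ordering-demo")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a", "2018-01-01"), (2, "b", "2018-01-02"), (3, "c", "2018-01-01"))
      .toDF("id", "name", "dt")

    // requiredOrdering for this write is (dt, bucket id derived from `id`, id): the planner
    // is now responsible for inserting the sort below DataWritingCommandExec, rather than
    // FileFormatWriter comparing and fixing the child's output ordering itself.
    df.write
      .partitionBy("dt")
      .bucketBy(4, "id")
      .sortBy("id")
      .saveAsTable("required_ordering_demo")

    spark.stop()
  }
}

If the incoming data already satisfies the declared ordering, the planner should leave the plan unchanged, which mirrors the effect of the orderingMatched check removed from FileFormatWriter.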