[SPARK-19256][SQL] Remove ordering enforcement from FileFormatWriter and let planner do that #20206

Closed · wants to merge 3 commits
DataWritingCommand.scala
@@ -21,8 +21,9 @@ import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.datasources.FileFormatWriter
@@ -60,5 +61,9 @@ trait DataWritingCommand extends Command {
new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
}

def requiredDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution)

def requiredOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
}
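These two hooks are what let the planner take over the work FileFormatWriter used to do itself. A minimal sketch of a command using them follows; the command name, the `eventDate` column, and the empty `run` body are hypothetical, and the `query`/`outputColumns` members are assumed to be wired up the same way as the concrete commands later in this diff:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand

// Hypothetical command that asks the planner to sort its input by an "eventDate"
// output column. DataWritingCommandExec forwards the request through
// requiredChildOrdering, EnsureRequirements inserts the SortExec, and `run` can
// then assume sorted input without sorting anything itself.
case class ExampleWriteCommand(
    query: LogicalPlan,
    outputColumns: Seq[Attribute]) extends DataWritingCommand {

  override def requiredOrdering: Seq[Seq[SortOrder]] =
    Seq(outputColumns.filter(_.name == "eventDate").map(SortOrder(_, Ascending)))

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    // Consume child.execute() and write it somewhere; omitted in this sketch.
    Seq.empty[Row]
  }
}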
commands.scala (DataWritingCommandExec)
@@ -23,9 +23,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -112,6 +113,10 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)

override def nodeName: String = "Execute " + cmd.nodeName

override def requiredChildDistribution: Seq[Distribution] = cmd.requiredDistribution

override def requiredChildOrdering: Seq[Seq[SortOrder]] = cmd.requiredOrdering

override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray

override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator
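With these overrides, the ordering requirement is satisfied by the regular planner machinery instead of by FileFormatWriter. Conceptually, the EnsureRequirements rule compares each child's outputOrdering with requiredChildOrdering and wraps the child in a local SortExec on a mismatch. The following is a simplified sketch of that check (not the actual rule, which also handles distributions and exchanges); it mirrors the hand-rolled comparison removed from FileFormatWriter below:

import org.apache.spark.sql.execution.{SortExec, SparkPlan}

// Simplified sketch of how the planner satisfies requiredChildOrdering: if the
// required ordering is not a semantic prefix of the child's outputOrdering,
// wrap the child in a non-global SortExec.
def ensureOrdering(operator: SparkPlan): SparkPlan = {
  val newChildren = operator.children.zip(operator.requiredChildOrdering).map {
    case (child, required) if required.isEmpty => child
    case (child, required) =>
      val satisfied = required.length <= child.outputOrdering.length &&
        required.zip(child.outputOrdering).forall {
          case (requiredOrder, childOrder) => requiredOrder.semanticEquals(childOrder)
        }
      if (satisfied) child else SortExec(required, global = false, child = child)
  }
  operator.withNewChildren(newChildren)
}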
FileFormatWriter.scala
@@ -109,7 +109,7 @@ object FileFormatWriter extends Logging {
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
bucketIdExpression: Option[Expression],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String])
: Set[String] = {
@@ -122,17 +122,6 @@
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains)

val bucketIdExpression = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}

val caseInsensitiveOptions = CaseInsensitiveMap(options)

// Note: prepareWrite has side effect. It sets "job".
@@ -156,40 +145,14 @@
statsTrackers = statsTrackers
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// the sort order doesn't matter
val actualOrdering = plan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
requiredOrdering.zip(actualOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
}
}

SQLExecution.checkSQLExecutionId(sparkSession)

// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

try {
val rdd = if (orderingMatched) {
plan.execute()
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
Review comment (Contributor): This concern is still valid, the DataWritingCommand.requiredChildOrdering is based on logical plan's output attribute ids, how can we safely apply it in DataWritingCommandExec?
val orderingExpr = requiredOrdering
.map(SortOrder(_, Ascending))
.map(BindReferences.bindReference(_, outputSpec.outputColumns))
SortExec(
orderingExpr,
global = false,
child = plan).execute()
}
val rdd = plan.execute()
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
rdd,
@@ -202,7 +165,7 @@
committer,
iterator = iter)
},
0 until rdd.partitions.length,
rdd.partitions.indices,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
@@ -521,18 +484,18 @@ object FileFormatWriter extends Logging {
var recordsInFile: Long = 0L
var fileCounter = 0
val updatedPartitions = mutable.Set[String]()
var currentPartionValues: Option[UnsafeRow] = None
var currentPartitionValues: Option[UnsafeRow] = None
var currentBucketId: Option[Int] = None

for (row <- iter) {
val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None

if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) {
// See a new partition or bucket - write to a new partition dir (or a new bucket file).
if (isPartitioned && currentPartionValues != nextPartitionValues) {
currentPartionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartionValues.get))
if (isPartitioned && currentPartitionValues != nextPartitionValues) {
currentPartitionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartitionValues.get))
}
if (isBucketed) {
currentBucketId = nextBucketId
@@ -543,7 +506,7 @@
fileCounter = 0

releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
} else if (desc.maxRecordsPerFile > 0 &&
recordsInFile >= desc.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
@@ -554,7 +517,7 @@
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
}
val outputRow = getOutputRow(row)
currentWriter.write(outputRow)
InsertIntoHadoopFsRelationCommand.scala
@@ -25,8 +25,10 @@ import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression,
SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -150,6 +152,10 @@ case class InsertIntoHadoopFsRelationCommand(
}
}

val partitionSet = AttributeSet(partitionColumns)
val dataColumns = query.output.filterNot(partitionSet.contains)
Review comment (Contributor): We should use outputColumns instead of query.output, cc @gengliangwang
Review comment (Member): +1, it should be outputColumns here, which is the output columns of the analyzed plan. See #20020 for details.
val bucketIdExpression = getBucketIdExpression(dataColumns)

val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
@@ -160,7 +166,7 @@
qualifiedOutputPath.toString, customPartitionLocations, outputColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
bucketIdExpression = bucketIdExpression,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
options = options)

@@ -184,6 +190,43 @@
Seq.empty[Row]
}

private def getBucketIdExpression(dataColumns: Seq[Attribute]): Option[Expression] = {
bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
}
}

/**
* How is `requiredOrdering` determined ?
*
* table type | requiredOrdering
* -----------------+-------------------------------------------------
* normal table | partition columns
Review comment (Contributor): nit: non-bucketed table, a partitioned table is not a normal table...
* bucketed table | (partition columns + bucketId + sort columns)
* -----------------+-------------------------------------------------
*/
override def requiredOrdering: Seq[Seq[SortOrder]] = {
val sortExpressions = bucketSpec match {
case Some(spec) =>
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = query.output.filterNot(partitionSet.contains)
val bucketIdExpression = getBucketIdExpression(dataColumns)
val sortColumns = spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
partitionColumns ++ bucketIdExpression ++ sortColumns

case _ => partitionColumns
}
if (sortExpressions.nonEmpty) {
Seq(sortExpressions.map(SortOrder(_, Ascending)))
} else {
Seq.fill(children.size)(Nil)
}
}

/**
* Deletes all partition files that match the specified static prefix. Partitions with custom
* locations are also cleared based on the custom locations map given to this class.
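To make the table above concrete, here is what the required ordering evaluates to for a hypothetical table partitioned by `date` and written with CLUSTERED BY (id) INTO 8 BUCKETS SORTED BY (ts); the column names, types, and bucket count are made up, and the construction simply replays getBucketIdExpression and requiredOrdering from above:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.{DateType, IntegerType, TimestampType}

// Hypothetical output columns of the query being written.
val date = AttributeReference("date", DateType)()    // partition column
val id   = AttributeReference("id", IntegerType)()   // bucket column
val ts   = AttributeReference("ts", TimestampType)() // bucket sort column

// Same expression getBucketIdExpression builds: hash(id) pmod 8, so the files'
// bucketing matches the hashing a shuffle would use for bucketed joins.
val bucketIdExpression = HashPartitioning(Seq(id), 8).partitionIdExpression

// Partition columns first, then the bucket id, then the bucket's sort columns:
// Seq(Seq(date ASC, pmod(hash(id), 8) ASC, ts ASC))
val requiredOrdering: Seq[Seq[SortOrder]] =
  Seq(Seq(date, bucketIdExpression, ts).map(SortOrder(_, Ascending)))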
FileStreamSink.scala
@@ -128,7 +128,7 @@ class FileStreamSink(
outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = None,
bucketIdExpression = None,
statsTrackers = Nil,
options = options)
}
InsertIntoHiveTable.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalog}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.CommandUtils
@@ -71,6 +71,29 @@ case class InsertIntoHiveTable(
ifPartitionNotExists: Boolean,
outputColumns: Seq[Attribute]) extends SaveAsHiveFile {


/**
 * For partitioned tables, `requiredOrdering` is over the dynamic partition columns of the table
*/
override def requiredOrdering: Seq[Seq[SortOrder]] = {
if (table.partitionColumnNames.nonEmpty) {
val numDynamicPartitions = partition.values.count(_.isEmpty)
val partitionAttributes = table.partitionColumnNames.takeRight(numDynamicPartitions).map {
name =>
query.resolve(
name :: Nil,
SparkSession.getActiveSession.get.sessionState.analyzer.resolver
).getOrElse {
throw new AnalysisException(
s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]")
}.asInstanceOf[Attribute]
}
Seq(partitionAttributes.map(SortOrder(_, Ascending)))
} else {
Seq.fill(children.size)(Nil)
}
}

/**
* Inserts all the rows in the table into Hive. Row objects are properly serialized with the
* `org.apache.hadoop.hive.serde2.SerDe` and the
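To make the static/dynamic split concrete: for a hypothetical INSERT ... PARTITION (country = 'US', dt) ... into a table partitioned by (country, dt), only `dt` is dynamic, so only `dt` contributes to the required ordering. The counting logic in isolation:

// Hypothetical spec for INSERT ... PARTITION (country = 'US', dt):
// `country` is static (a value is given), `dt` is dynamic (None).
val partition: Map[String, Option[String]] = Map("country" -> Some("US"), "dt" -> None)
val partitionColumnNames = Seq("country", "dt")

val numDynamicPartitions = partition.values.count(_.isEmpty)                       // 1
val dynamicPartitionColumns = partitionColumnNames.takeRight(numDynamicPartitions) // Seq("dt")

// requiredOrdering then asks for the input sorted by `dt` ascending, so each write
// task sees the rows of one dynamic partition contiguously and only ever needs a
// single open output writer.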
SaveAsHiveFile.scala
@@ -83,7 +83,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand {
FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, allColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionAttributes,
bucketSpec = None,
bucketIdExpression = None,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
options = Map.empty)
}
Expand Down