[SPARK-21649][SQL] Support writing data into hive bucket table. #18866

Closed · wants to merge 2 commits
@@ -57,6 +57,8 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
*/
private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId)

private var fileNameWithPartitionId: Boolean = true

protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
val format = context.getOutputFormatClass.newInstance()
// If OutputFormat is Configurable, we should set conf to it.
@@ -103,7 +105,12 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
// the file name is fine and won't overflow.
val split = taskContext.getTaskAttemptID.getTaskID.getId
f"part-$split%05d-$jobId$ext"
if (fileNameWithPartitionId) {
f"part-$split%05d-$jobId$ext"
} else {
// File names created by different tasks should have different `ext` when `split` is not used.
f"part-$jobId$ext"
}
}

override def setupJob(jobContext: JobContext): Unit = {
@@ -118,6 +125,8 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString)
jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true)
jobContext.getConfiguration.setInt("mapreduce.task.partition", 0)
fileNameWithPartitionId =
jobContext.getConfiguration.getBoolean("spark.sql.bucket.fileNameWithPartitionId", true)

val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId)
committer = setupCommitter(taskAttemptContext)
@@ -126,7 +135,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)

override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
committer.commitJob(jobContext)
val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]])
val filesToMove = taskCommits.map(_.obj.asInstanceOf[HadoopMRTaskCommitStatus].absPathFiles)
.foldLeft(Map[String, String]())(_ ++ _)
logDebug(s"Committing files staged for absolute locations $filesToMove")
val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
@@ -152,7 +161,22 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
val attemptId = taskContext.getTaskAttemptID
SparkHadoopMapRedUtil.commitTask(
committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
new TaskCommitMessage(addedAbsPathFiles.toMap)
val committedPaths = mutable.HashSet[String]()
committer match {
case fileOutputCommitter: FileOutputCommitter =>
val committedPath = fileOutputCommitter.getCommittedTaskPath(taskContext)
if (committedPath != null) {
committedPaths += committedPath.toString
}
case _ =>
committedPaths += path
}
if (path != null) {
committedPaths += absPathStagingDir.toString
}

new TaskCommitMessage(
HadoopMRTaskCommitStatus(addedAbsPathFiles.toMap, committedPaths.toSeq))
}

override def abortTask(taskContext: TaskAttemptContext): Unit = {
@@ -164,3 +188,5 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
}
}
}

case class HadoopMRTaskCommitStatus(absPathFiles: Map[String, String], commitPaths: Seq[String])
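To make the naming change concrete, here is a minimal standalone sketch (plain Scala, not Spark code) of the two schemes in `getFilename`; the sample `jobId` and `ext` values are invented. When `spark.sql.bucket.fileNameWithPartitionId` is set to false, uniqueness has to come from a per-task `ext`, as the inline comment above says.

```scala
// Standalone illustration of the two naming schemes; the values are made up.
object FileNameSchemeSketch {
  def fileName(split: Int, jobId: String, ext: String, withPartitionId: Boolean): String =
    if (withPartitionId) {
      f"part-$split%05d-$jobId$ext" // default scheme
    } else {
      // Hive-bucket scheme: uniqueness must then come from a per-task `ext`.
      f"part-$jobId$ext"
    }

  def main(args: Array[String]): Unit = {
    println(fileName(3, "a1b2c3", ".c000.snappy.parquet", withPartitionId = true))
    // part-00003-a1b2c3.c000.snappy.parquet
    println(fileName(3, "a1b2c3", "_00003.c000.snappy.orc", withPartitionId = false))
    // part-a1b2c3_00003.c000.snappy.orc
  }
}
```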
@@ -202,14 +202,14 @@ object CatalogUtils {
tableCols: Seq[String],
bucketSpec: BucketSpec,
resolver: Resolver): BucketSpec = {
val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec
val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, isHiveBucket) = bucketSpec
val normalizedBucketCols = bucketColumnNames.map { colName =>
normalizeColumnName(tableName, tableCols, colName, "bucket", resolver)
}
val normalizedSortCols = sortColumnNames.map { colName =>
normalizeColumnName(tableName, tableCols, colName, "sort", resolver)
}
BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols)
BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols, isHiveBucket)
}

/**
@@ -155,11 +155,13 @@ case class CatalogTablePartition(
* @param numBuckets number of buckets.
* @param bucketColumnNames the names of the columns that used to generate the bucket id.
* @param sortColumnNames the names of the columns that used to sort data in each bucket.
* @param isHiveBucket whether the spec is for a Hive bucket table.
*/
case class BucketSpec(
numBuckets: Int,
bucketColumnNames: Seq[String],
sortColumnNames: Seq[String]) {
sortColumnNames: Seq[String],
isHiveBucket: Boolean = false) {
if (numBuckets <= 0 || numBuckets >= 100000) {
throw new AnalysisException(
s"Number of buckets should be greater than 0 but less than 100000. Got `$numBuckets`")
@@ -172,7 +174,12 @@ case class BucketSpec(
} else {
""
}
s"$numBuckets buckets, $bucketString$sortString"
val isHiveBucketString = if (isHiveBucket) {
", it is a hive bucket table."
} else {
", it is not a hive bucket table."
}
s"$numBuckets buckets, $bucketString$sortString$isHiveBucketString"
}

def toLinkedHashMap: mutable.LinkedHashMap[String, String] = {
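A minimal sketch of how the extended `BucketSpec` is meant to be used, assuming the patched class above is on the classpath; the column name is invented. The new field defaults to `false`, so existing three-argument call sites keep their behaviour, and Hive code paths opt in explicitly via `copy`.

```scala
import org.apache.spark.sql.catalyst.catalog.BucketSpec

object BucketSpecSketch {
  def main(args: Array[String]): Unit = {
    // Existing three-argument call sites are unchanged: the flag defaults to false.
    val dataSourceSpec = BucketSpec(8, Seq("user_id"), Seq("user_id"))
    // Hive DDL paths opt in explicitly (see the SparkSqlAstBuilder change below).
    val hiveSpec = dataSourceSpec.copy(isHiveBucket = true)

    println(dataSourceSpec.isHiveBucket) // false
    println(hiveSpec.isHiveBucket)       // true
  }
}
```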
@@ -47,9 +47,11 @@ case object AllTuples extends Distribution
* Represents data where tuples that share the same values for the `clustering`
* [[Expression Expressions]] will be co-located. Based on the context, this
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
* within a single partition. `clustersOpt`, when defined, specifies the required number of
* partitions. `useHiveHash` indicates whether Hive hash should be used when partitioning.
*/
case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
case class ClusteredDistribution(clustering: Seq[Expression], clustersOpt: Option[Int] = None,
useHiveHash: Boolean = false) extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
@@ -234,7 +236,8 @@ case object SinglePartition extends Partitioning {
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition.
*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int,
useHiveHash: Boolean = false)
extends Expression with Partitioning with Unevaluable {

override def children: Seq[Expression] = expressions
@@ -243,7 +246,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredDistribution(requiredClustering) =>
case ClusteredDistribution(requiredClustering, clustersOpt, clusteredByHiveHash)
if (clustersOpt.forall(_ == numPartitions) && clusteredByHiveHash == useHiveHash) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}
@@ -260,9 +264,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

/**
* Returns an expression that will produce a valid partition ID(i.e. non-negative and is less
* than numPartitions) based on hashing expressions.
* than numPartitions) based on hashing expressions. `HiveHash` will be used when
* `useHiveHash` is true, for compatibility when inserting data into a Hive bucket table.
*/
def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
def partitionIdExpression: Expression =
if (useHiveHash) {
Pmod(new HiveHash(expressions), Literal(numPartitions))
} else {
Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
}
}

/**
@@ -289,7 +299,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering) =>
case ClusteredDistribution(requiredClustering, clustersOpt, _)
if clustersOpt.forall(_ == numPartitions) =>
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}
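A minimal sketch of how the new flags interact, assuming the patched `HashPartitioning` and `ClusteredDistribution` above and spark-catalyst on the classpath; the column name is invented. A hash partitioning only satisfies a clustered distribution that asks for the same number of clusters and the same hash family, and `partitionIdExpression` switches between `HiveHash` and `Murmur3Hash` accordingly.

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.IntegerType

object HiveHashPartitioningSketch {
  def main(args: Array[String]): Unit = {
    val bucketCol = AttributeReference("user_id", IntegerType)()
    val hivePartitioning = HashPartitioning(Seq(bucketCol), numPartitions = 8, useHiveHash = true)

    // Satisfied only when the cluster count and the hash family both match,
    // so the planner adds an exchange for any mismatch.
    println(hivePartitioning.satisfies(
      ClusteredDistribution(Seq(bucketCol), Some(8), useHiveHash = true)))  // true
    println(hivePartitioning.satisfies(
      ClusteredDistribution(Seq(bucketCol), Some(8), useHiveHash = false))) // false

    // The bucket id expression switches hash families with the flag:
    // Pmod over HiveHash here, Pmod over Murmur3Hash otherwise.
    println(hivePartitioning.partitionIdExpression)
  }
}
```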
@@ -79,6 +79,31 @@ class DistributionSuite extends SparkFunSuite {
ClusteredDistribution(Seq('d, 'e)),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b), 10),
ClusteredDistribution(Seq('a, 'b), Some(10)),
true)

checkSatisfied(
HashPartitioning(Seq('a, 'b), 10),
ClusteredDistribution(Seq('a, 'b), Some(5)),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b), 10, useHiveHash = true),
ClusteredDistribution(Seq('a, 'b), Some(10), useHiveHash = true),
true)

checkSatisfied(
HashPartitioning(Seq('a, 'b), 10, useHiveHash = false),
ClusteredDistribution(Seq('a, 'b), Some(10), useHiveHash = true),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b), 10, useHiveHash = true),
ClusteredDistribution(Seq('a, 'b), Some(10), useHiveHash = false),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
AllTuples,
@@ -52,4 +52,12 @@ class PartitioningSuite extends SparkFunSuite {
assert(partitioningA.guarantees(partitioningA))
assert(partitioningA.compatibleWith(partitioningA))
}

test("HashPartitioning compatibility should be sensitive to whether Hive hash is used.") {
val expressions = Seq(Literal(2), Literal(3))
val partitioningA = HashPartitioning(expressions, 100, useHiveHash = false)
val partitioningB = HashPartitioning(expressions, 100, useHiveHash = true)
assert(!partitioningA.compatibleWith(partitioningB))
assert(!partitioningA.guarantees(partitioningB))
}
}
@@ -427,12 +427,13 @@ class TreeNodeSuite extends SparkFunSuite {

// Converts BucketSpec to JSON
assertJSON(
BucketSpec(1, Seq("bucket"), Seq("sort")),
BucketSpec(1, Seq("bucket"), Seq("sort"), true),
JObject(
"product-class" -> classOf[BucketSpec].getName,
"numBuckets" -> 1,
"bucketColumnNames" -> "[bucket]",
"sortColumnNames" -> "[sort]"))
"sortColumnNames" -> "[sort]",
"isHiveBucket" -> true))

// Converts WindowFrame to JSON
assertJSON(
@@ -1093,7 +1093,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil)
val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)
val selectQuery = Option(ctx.query).map(plan)
val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec)
val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec).map(_.copy(isHiveBucket = true))

// Note: Hive requires partition columns to be distinct from the schema, so we need
// to include the partition columns here explicitly
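A minimal sketch of the DDL path this parser change covers, assuming a Spark build that includes this patch and has Hive support enabled; the table and column names are invented.

```scala
import org.apache.spark.sql.SparkSession

object HiveBucketDdlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .enableHiveSupport()
      .getOrCreate()

    // Parsed by the Hive CREATE TABLE path above; the parser tags the resulting
    // BucketSpec with isHiveBucket = true, so a later INSERT into this table can
    // shuffle its input with HiveHash into 8 buckets.
    spark.sql(
      """CREATE TABLE bucketed_events (user_id INT, payload STRING)
        |CLUSTERED BY (user_id) INTO 8 BUCKETS
        |STORED AS ORC""".stripMargin)

    spark.stop()
  }
}
```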
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -50,6 +51,8 @@ trait RunnableCommand extends logical.Command {
def run(sparkSession: SparkSession): Seq[Row] = {
throw new NotImplementedError
}

def requiredDistribution: Option[Seq[Distribution]] = None
}

/**
@@ -97,6 +100,11 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e
protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}

/** Specifies any partition requirements on the input data for this operator. */
override def requiredChildDistribution: Seq[Distribution] = {
cmd.requiredDistribution.getOrElse(super.requiredChildDistribution)
}
}

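A standalone sketch (plain Scala, no Spark dependencies) of the hook pattern added above: a command optionally declares the distribution it needs, and the exec node falls back to the planner's default when it does not. The types here are simplified stand-ins for the Spark ones, and the bucket-writing command is hypothetical, not part of this diff.

```scala
// Simplified stand-ins for the Spark types; only the hook mechanics are shown.
object RequiredDistributionHookSketch {
  sealed trait Distribution
  case object UnspecifiedDistribution extends Distribution
  final case class ClusteredDistribution(
      columns: Seq[String], clusters: Option[Int], useHiveHash: Boolean) extends Distribution

  trait RunnableCommand {
    // Mirrors the new hook: None means "no special requirement".
    def requiredDistribution: Option[Seq[Distribution]] = None
  }

  final case class ExecutedCommandExec(cmd: RunnableCommand, childCount: Int) {
    // Mirrors requiredChildDistribution: the command's requirement if present,
    // otherwise the usual default of UnspecifiedDistribution per child.
    def requiredChildDistribution: Seq[Distribution] =
      cmd.requiredDistribution.getOrElse(Seq.fill(childCount)(UnspecifiedDistribution))
  }

  def main(args: Array[String]): Unit = {
    val plainCommand = new RunnableCommand {}
    // Hypothetical bucket-writing command: asks for rows clustered by the bucket
    // column with Hive hash into 8 partitions before it runs.
    val bucketWriteCommand = new RunnableCommand {
      override def requiredDistribution: Option[Seq[Distribution]] =
        Some(Seq(ClusteredDistribution(Seq("user_id"), Some(8), useHiveHash = true)))
    }

    println(ExecutedCommandExec(plainCommand, childCount = 1).requiredChildDistribution)
    // List(UnspecifiedDistribution)
    println(ExecutedCommandExec(bucketWriteCommand, childCount = 1).requiredChildDistribution)
    // List(ClusteredDistribution(List(user_id),Some(8),true))
  }
}
```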