[SPARK-26188][SQL] FileIndex: don't infer data types of partition columns if user specifies schema #23165

Closed · wants to merge 5 commits
Changes from 4 commits
@@ -126,33 +126,15 @@ abstract class PartitioningAwareFileIndex(
val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
val inferredPartitionSpec = PartitioningUtils.parsePartitions(

val caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis
PartitioningUtils.parsePartitions(
leafDirs,
typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
basePaths = basePaths,
userSpecifiedSchema = userSpecifiedSchema,
caseSensitive = caseSensitive,
timeZoneId = timeZoneId)
userSpecifiedSchema match {
case Some(userProvidedSchema) if userProvidedSchema.nonEmpty =>
val userPartitionSchema =
combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec)
Contributor: we can remove combineInferredAndUserSpecifiedPartitionSchema now

// we need to cast into the data type that user specified.
def castPartitionValuesToUserSchema(row: InternalRow) = {
InternalRow((0 until row.numFields).map { i =>
val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType
Cast(
Literal.create(row.get(i, dt), dt),
userPartitionSchema.fields(i).dataType,
Option(timeZoneId)).eval()
}: _*)
}

PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part =>
part.copy(values = castPartitionValuesToUserSchema(part.values))
})
case _ =>
inferredPartitionSpec
}
}

private def prunePartitions(
@@ -233,25 +215,6 @@ abstract class PartitioningAwareFileIndex(
val name = path.getName
!((name.startsWith("_") && !name.contains("=")) || name.startsWith("."))
}

/**
* In the read path, only managed tables by Hive provide the partition columns properly when
* initializing this class. All other file based data sources will try to infer the partitioning,
* and then cast the inferred types to user specified dataTypes if the partition columns exist
* inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or
* inconsistent data types as reported in SPARK-21463.
* @param spec A partition inference result
* @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema`
*/
private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = {
val equality = sparkSession.sessionState.conf.resolver
val resolved = spec.partitionColumns.map { partitionField =>
// SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
partitionField)
}
StructType(resolved)
}
}

object PartitioningAwareFileIndex {
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

@@ -94,18 +94,34 @@ object PartitioningUtils {
paths: Seq[Path],
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
timeZoneId: String): PartitionSpec = {
parsePartitions(paths, typeInference, basePaths, DateTimeUtils.getTimeZone(timeZoneId))
parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema,
caseSensitive, DateTimeUtils.getTimeZone(timeZoneId))
}

private[datasources] def parsePartitions(
paths: Seq[Path],
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
timeZone: TimeZone): PartitionSpec = {
val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) {
Contributor: can we build this at the caller side, out of PartitioningUtils? Then we only need one extra parameter.

Member (author): Personally I prefer to make the parameters simple and easy to understand, so that the caller's logic (outside PartitioningUtils) looks cleaner.
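For reference, the caller-side alternative being discussed might look roughly like the following inside PartitioningAwareFileIndex (a sketch only; the parsePartitions overload taking a pre-built map is hypothetical and not what this PR implements):

// Hypothetical caller-side construction of the name -> DataType map, so that
// PartitioningUtils would only need one extra parameter.
val userSpecifiedDataTypes: Map[String, DataType] = {
  val nameToDataType = userSpecifiedSchema
    .map(_.fields.map(f => f.name -> f.dataType).toMap)
    .getOrElse(Map.empty[String, DataType])
  if (caseSensitive) nameToDataType else CaseInsensitiveMap(nameToDataType)
}

PartitioningUtils.parsePartitions(
  leafDirs,
  typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
  basePaths = basePaths,
  userSpecifiedDataTypes = userSpecifiedDataTypes,  // hypothetical parameter
  timeZoneId = timeZoneId)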

val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap
if (!caseSensitive) {
CaseInsensitiveMap(nameToDataType)
Contributor: isn't this if !caseSensitive?

Member (author): Yes, thanks for pointing it out :)

} else {
nameToDataType
}
} else {
Map.empty[String, DataType]
}
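Illustrating the case-sensitivity point from the review thread above (a minimal sketch, not part of the change): with case-insensitive analysis, a directory such as A=1 should still resolve against a user-specified column named a, which is why the map is wrapped.

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types.{DataType, StringType}

// Lookups into CaseInsensitiveMap ignore the key's case.
val types: Map[String, DataType] = CaseInsensitiveMap(Map("a" -> StringType))
types.contains("A")  // true: a directory "A=1" still matches the user-specified column "a"
types.get("A")       // Some(StringType)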

// First, we need to parse every partition's path and see if we can find partition values.
val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
parsePartition(path, typeInference, basePaths, timeZone)
parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, timeZone)
}.unzip

// We create pairs of (path -> path's partition value) here
@@ -147,13 +163,13 @@ object PartitioningUtils {
columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
// We always assume partition columns are nullable since we've no idea whether null values
// will be appended in the future.
StructField(name, dataType, nullable = true)
StructField(name, userSpecifiedDataTypes.getOrElse(name, dataType), nullable = true)
}
}

// Finally, we create `Partition`s based on paths and resolved partition values.
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map {
case (PartitionValues(_, literals), (path, _)) =>
case (PartitionValues(columnNames, literals), (path, _)) =>
Contributor: unnecessary change?

PartitionPath(InternalRow.fromSeq(literals.map(_.value)), path)
}

@@ -185,6 +201,7 @@ object PartitioningUtils {
path: Path,
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedDataTypes: Map[String, DataType],
timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
val columns = ArrayBuffer.empty[(String, Literal)]
// Old Hadoop versions don't have `Path.isRoot`
@@ -206,7 +223,7 @@ object PartitioningUtils {
// Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
// Once we get the string, we try to parse it and find the partition column and value.
val maybeColumn =
parsePartitionColumn(currentPath.getName, typeInference, timeZone)
parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, timeZone)
maybeColumn.foreach(columns += _)

// Now, we determine if we should stop.
@@ -239,6 +256,7 @@ object PartitioningUtils {
private def parsePartitionColumn(
columnSpec: String,
typeInference: Boolean,
userSpecifiedDataTypes: Map[String, DataType],
timeZone: TimeZone): Option[(String, Literal)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
Expand All @@ -250,7 +268,16 @@ object PartitioningUtils {
val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")

val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
val literal = if (userSpecifiedDataTypes.contains(columnName)) {
// SPARK-26188: if user provides corresponding column schema, get the column value without
// inference, and then cast it as user specified data type.
val columnValue = inferPartitionColumnValue(rawColumnValue, false, timeZone)
val castedValue =
Cast(columnValue, userSpecifiedDataTypes(columnName), Option(timeZone.getID)).eval()
Literal.create(castedValue, userSpecifiedDataTypes(columnName))
} else {
inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
}
Some(columnName -> literal)
}
}
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}

class FileIndexSuite extends SharedSQLContext {
@@ -49,6 +50,21 @@ class FileIndexSuite extends SharedSQLContext {
}
}

test("SPARK-26188: don't infer data types of partition columns if user specifies schema") {
withTempDir { dir =>
val partitionDirectory = new File(dir, s"a=4d")
partitionDirectory.mkdir()
val file = new File(partitionDirectory, "text.txt")
stringToFile(file, "text")
val path = new Path(dir.getCanonicalPath)
val schema = StructType(Seq(StructField("a", StringType, false)))
val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema))
val partitionValues = fileIndex.partitionSpec().partitions.map(_.values)
assert(partitionValues.length == 1 && partitionValues(0).numFields == 1 &&
partitionValues(0).getString(0) == "4d")
}
}
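As a user-facing illustration of what this test pins down (a hypothetical sketch with made-up paths, not part of the suite): when the schema is given explicitly, a partition value such as 4d should come back as the literal string rather than going through type inference first.

import org.apache.spark.sql.types.{StringType, StructType}

// Hypothetical layout: /tmp/tbl/a=4d/part.csv, where part.csv holds a single "value" column.
val userSchema = new StructType()
  .add("value", StringType)
  .add("a", StringType)

val df = spark.read.schema(userSchema).csv("/tmp/tbl")
df.select("a").distinct().show()  // expected: a single row containing 4d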

test("InMemoryFileIndex: input paths are converted to qualified paths") {
withTempDir { dir =>
val file = new File(dir, "text.txt")
@@ -101,7 +101,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
"hdfs://host:9000/path/a=10.5/b=hello")

var exception = intercept[AssertionError] {
parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId)
parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))

@@ -115,6 +115,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/")),
None,
true,
timeZoneId)

// Valid
@@ -128,6 +130,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/something=true/table")),
None,
true,
timeZoneId)

// Valid
@@ -141,6 +145,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/table=true")),
None,
true,
timeZoneId)

// Invalid
Expand All @@ -154,6 +160,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/")),
None,
true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
@@ -174,20 +182,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/tmp/tables/")),
None,
true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
}

test("parse partition") {
def check(path: String, expected: Option[PartitionValues]): Unit = {
val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1
val actual = parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)._1
assert(expected === actual)
}

def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
val message = intercept[T] {
parsePartition(new Path(path), true, Set.empty[Path], timeZone)
parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)
}.getMessage

assert(message.contains(expected))
@@ -231,6 +241,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
path = new Path("file://path/a=10"),
typeInference = true,
basePaths = Set(new Path("file://path/a=10")),
Map.empty,
timeZone = timeZone)._1

assert(partitionSpec1.isEmpty)
@@ -240,6 +251,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
path = new Path("file://path/a=10"),
typeInference = true,
basePaths = Set(new Path("file://path")),
Map.empty,
timeZone = timeZone)._1

assert(partitionSpec2 ==
@@ -258,6 +270,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
rootPaths,
None,
true,
timeZoneId)
assert(actualSpec.partitionColumns === spec.partitionColumns)
assert(actualSpec.partitions.length === spec.partitions.length)
@@ -370,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
test("parse partitions with type inference disabled") {
def check(paths: Seq[String], spec: PartitionSpec): Unit = {
val actualSpec =
parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId)
parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, true, timeZoneId)
assert(actualSpec === spec)
}
