From b5eae59c5548a22cf9a1d67115e59600fb6cb9d4 Mon Sep 17 00:00:00 2001
From: jackierwzhang
Date: Fri, 18 Feb 2022 23:12:05 +0800
Subject: [PATCH] [SPARK-38094] Enable matching schema column names by field ids

### What changes were proposed in this pull request?

Field ID is a native field in the Parquet schema (https://github.com/apache/parquet-format/blob/master/src/main/thrift/parquet.thrift#L398).

After this PR, when the requested schema has field IDs, Parquet readers will first use the field ID to determine which Parquet columns to read, provided the field ID exists in the Spark schema, before falling back to matching by column name.

This PR supports:
- Vectorized reader
- parquet-mr reader

### Why are the changes needed?

It enables matching columns by field ID for supported table formats like Iceberg and Delta. Specifically, it enables easy conversion from Iceberg (which uses field IDs rather than names) to Delta, and allows `id` mode for Delta [column mapping](https://docs.databricks.com/delta/delta-column-mapping.html).

### Does this PR introduce _any_ user-facing change?

This PR introduces three new configurations:

`spark.sql.parquet.fieldId.write.enabled`: If enabled, Spark will write out native field IDs, stored inside StructField metadata as `parquet.field.id`, to Parquet files. This configuration defaults to `true`.

`spark.sql.parquet.fieldId.read.enabled`: If enabled, Spark will attempt to read field IDs in Parquet files and use them for matching columns. This configuration defaults to `false`, so Spark keeps its existing behavior by default.

`spark.sql.parquet.fieldId.read.ignoreMissing`: If enabled, Spark will read Parquet files that do not have any field IDs while attempting to match columns by ID against the Spark schema; nulls are returned for Spark columns without a match. This configuration defaults to `false`, so Spark can alert the user when field ID matching is expected but the Parquet files do not contain any IDs.

### How was this patch tested?

Existing tests + new unit tests.

Closes #35385 from jackierwzhang/SPARK-38094-field-ids.
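For illustration only, a minimal end-to-end sketch of the new behavior, adapted from the `ParquetFieldIdIOSuite` added in this PR (not itself part of the patch). It assumes a running `spark` session, e.g. in spark-shell; the output path is a placeholder.

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructType}

// Field IDs travel in StructField metadata under the key "parquet.field.id".
def withId(id: Int) =
  new MetadataBuilder().putLong("parquet.field.id", id).build()

// Write a file whose Parquet schema carries field IDs
// (spark.sql.parquet.fieldId.write.enabled is true by default).
val writeSchema = new StructType()
  .add("random", IntegerType, nullable = true, withId(1))
  .add("name", StringType, nullable = true, withId(0))
val rows = Seq(Row(100, "text"), Row(200, "more"))
spark.createDataFrame(rows.asJava, writeSchema)
  .write.mode("overwrite").parquet("/tmp/field_id_demo")

// Read it back, resolving columns by field ID instead of by name.
spark.conf.set("spark.sql.parquet.fieldId.read.enabled", "true")
val readSchema = new StructType()
  .add("a", StringType, nullable = true, withId(0))  // resolves to "name" via ID 0
  .add("b", IntegerType, nullable = true, withId(1)) // resolves to "random" via ID 1
spark.read.schema(readSchema).parquet("/tmp/field_id_demo").show()
// +----+---+
// |   a|  b|
// +----+---+
// |text|100|
// |more|200|
// +----+---+
```
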
Authored-by: jackierwzhang Signed-off-by: Wenchen Fan --- .../sql/errors/QueryExecutionErrors.scala | 9 + .../apache/spark/sql/internal/SQLConf.scala | 33 ++ .../parquet/ParquetFileFormat.scala | 4 + .../parquet/ParquetReadSupport.scala | 249 +++++++-- .../parquet/ParquetRowConverter.scala | 34 +- .../parquet/ParquetSchemaConverter.scala | 18 +- .../datasources/parquet/ParquetUtils.scala | 44 +- .../datasources/v2/parquet/ParquetWrite.scala | 4 + .../parquet/ParquetFieldIdIOSuite.scala | 213 +++++++ .../parquet/ParquetFieldIdSchemaSuite.scala | 528 ++++++++++++++++++ .../parquet/ParquetSchemaSuite.scala | 10 +- .../datasources/parquet/ParquetTest.scala | 12 +- .../spark/sql/test/TestSQLContext.scala | 8 +- 13 files changed, 1088 insertions(+), 78 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 1d87e9f0a992b..d1db0177dfd23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -808,6 +808,15 @@ object QueryExecutionErrors { """.stripMargin.replaceAll("\n", " ")) } + def foundDuplicateFieldInFieldIdLookupModeError( + requiredId: Int, matchedFields: String): Throwable = { + new RuntimeException( + s""" + |Found duplicate field(s) "$requiredId": $matchedFields + |in id mapping mode + """.stripMargin.replaceAll("\n", " ")) + } + def failedToMergeIncompatibleSchemasError( left: StructType, right: StructType, e: Throwable): Throwable = { new SparkException(s"Failed to merge incompatible schemas $left and $right", e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 59a896a29b6f2..3a7ce650ea633 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -934,6 +934,33 @@ object SQLConf { .intConf .createWithDefault(4096) + val PARQUET_FIELD_ID_WRITE_ENABLED = + buildConf("spark.sql.parquet.fieldId.write.enabled") + .doc("Field ID is a native field of the Parquet schema spec. When enabled, " + + "Parquet writers will populate the field Id " + + "metadata (if present) in the Spark schema to the Parquet schema.") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + + val PARQUET_FIELD_ID_READ_ENABLED = + buildConf("spark.sql.parquet.fieldId.read.enabled") + .doc("Field ID is a native field of the Parquet schema spec. 
When enabled, Parquet readers " + + "will use field IDs (if present) in the requested Spark schema to look up Parquet " + + "fields instead of using column names") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val IGNORE_MISSING_PARQUET_FIELD_ID = + buildConf("spark.sql.parquet.fieldId.read.ignoreMissing") + .doc("When the Parquet file doesn't have any field IDs but the " + + "Spark read schema is using field IDs to read, we will silently return nulls " + + "when this flag is enabled, or error otherwise.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -4253,6 +4280,12 @@ class SQLConf extends Serializable with Logging { def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) + def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED) + + def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED) + + def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID) + def useV1Command: Boolean = getConf(SQLConf.LEGACY_USE_V1_COMMAND) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index aa6f9ee91656d..18876dedb951e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -119,6 +119,10 @@ class ParquetFileFormat SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sparkSession.sessionState.conf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index bdab0f7892f00..97e691ff7c66c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.time.ZoneId -import java.util.{Locale, Map => JMap} +import java.util +import java.util.{Locale, Map => JMap, UUID} import scala.collection.JavaConverters._ @@ -85,13 +86,71 @@ class ParquetReadSupport( StructType.fromString(schemaString) } + val parquetRequestedSchema = ParquetReadSupport.getRequestedSchema( + context.getFileSchema, catalystRequestedSchema, conf, enableVectorizedReader) + new ReadContext(parquetRequestedSchema, new util.HashMap[String, String]()) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. 
+ * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. + */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + val parquetRequestedSchema = readContext.getRequestedSchema + new ParquetRecordMaterializer( + parquetRequestedSchema, + ParquetReadSupport.expandUDT(catalystRequestedSchema), + new ParquetToSparkSchemaConverter(conf), + convertTz, + datetimeRebaseSpec, + int96RebaseSpec) + } +} + +object ParquetReadSupport extends Logging { + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + def generateFakeColumnName: String = s"_fake_name_${UUID.randomUUID()}" + + def getRequestedSchema( + parquetFileSchema: MessageType, + catalystRequestedSchema: StructType, + conf: Configuration, + enableVectorizedReader: Boolean): MessageType = { val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key, SQLConf.CASE_SENSITIVE.defaultValue.get) val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) - val parquetFileSchema = context.getFileSchema + val useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key, + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get) + val ignoreMissingIds = conf.getBoolean(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key, + SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.defaultValue.get) + + if (!ignoreMissingIds && + !containsFieldIds(parquetFileSchema) && + ParquetUtils.hasFieldIds(catalystRequestedSchema)) { + throw new RuntimeException( + "Spark read schema expects field Ids, " + + "but Parquet file schema doesn't contain any field Ids.\n" + + "Please remove the field ids from Spark schema or ignore missing ids by " + + "setting `spark.sql.parquet.fieldId.ignoreMissing = true`\n" + + s""" + |Spark read schema: + |${catalystRequestedSchema.prettyJson} + | + |Parquet file schema: + |${parquetFileSchema.toString} + |""".stripMargin) + } val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema, - catalystRequestedSchema, caseSensitive) + catalystRequestedSchema, caseSensitive, useFieldId) // We pass two schema to ParquetRecordMaterializer: // - parquetRequestedSchema: the schema of the file data we want to read @@ -109,6 +168,7 @@ class ParquetReadSupport( // in parquetRequestedSchema which are not present in the file. parquetClippedSchema } + logDebug( s"""Going to read the following fields from the Parquet file with the following schema: |Parquet file schema: @@ -120,34 +180,20 @@ class ParquetReadSupport( |Catalyst requested schema: |${catalystRequestedSchema.treeString} """.stripMargin) - new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + + parquetRequestedSchema } /** - * Called on executor side after [[init()]], before instantiating actual Parquet record readers. - * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet - * records to Catalyst [[InternalRow]]s. 
+ * Overloaded method for backward compatibility with + * `caseSensitive` default to `true` and `useFieldId` default to `false` */ - override def prepareForRead( - conf: Configuration, - keyValueMetaData: JMap[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - val parquetRequestedSchema = readContext.getRequestedSchema - new ParquetRecordMaterializer( - parquetRequestedSchema, - ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetToSparkSchemaConverter(conf), - convertTz, - datetimeRebaseSpec, - int96RebaseSpec) + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + clipParquetSchema(parquetSchema, catalystSchema, caseSensitive, useFieldId = false) } -} - -object ParquetReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" /** * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist @@ -156,9 +202,10 @@ object ParquetReadSupport { def clipParquetSchema( parquetSchema: MessageType, catalystSchema: StructType, - caseSensitive: Boolean = true): MessageType = { + caseSensitive: Boolean, + useFieldId: Boolean): MessageType = { val clippedParquetFields = clipParquetGroupFields( - parquetSchema.asGroupType(), catalystSchema, caseSensitive) + parquetSchema.asGroupType(), catalystSchema, caseSensitive, useFieldId) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -170,26 +217,36 @@ object ParquetReadSupport { } private def clipParquetType( - parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { - catalystType match { + parquetType: Type, + catalystType: DataType, + caseSensitive: Boolean, + useFieldId: Boolean): Type = { + val newParquetType = catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive, useFieldId) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + clipParquetMapType( + parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. parquetType } + + if (useFieldId && parquetType.getId != null) { + newParquetType.withId(parquetType.getId.intValue()) + } else { + newParquetType + } } /** @@ -210,7 +267,10 @@ object ParquetReadSupport { * [[StructType]]. */ private def clipParquetListType( - parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { + parquetList: GroupType, + elementType: DataType, + caseSensitive: Boolean, + useFieldId: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. 
assert(!isPrimitiveCatalystType(elementType)) @@ -218,7 +278,7 @@ object ParquetReadSupport { // list element type is just the group itself. Clip it. if (parquetList.getLogicalTypeAnnotation == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType, caseSensitive) + clipParquetType(parquetList, elementType, caseSensitive, useFieldId) } else { assert( parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation], @@ -250,19 +310,28 @@ object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive, useFieldId)) .named(parquetList.getName) } else { + val newRepeatedGroup = Types + .repeatedGroup() + .addField( + clipParquetType( + repeatedGroup.getType(0), elementType, caseSensitive, useFieldId)) + .named(repeatedGroup.getName) + + val newElementType = if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + // Otherwise, the repeated field's type is the element type with the repeated field's // repetition. Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField( - Types - .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) - .named(repeatedGroup.getName)) + .addField(newElementType) .named(parquetList.getName) } } @@ -277,7 +346,8 @@ object ParquetReadSupport { parquetMap: GroupType, keyType: DataType, valueType: DataType, - caseSensitive: Boolean): GroupType = { + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -285,13 +355,19 @@ object ParquetReadSupport { val parquetKeyType = repeatedGroup.getType(0) val parquetValueType = repeatedGroup.getType(1) - val clippedRepeatedGroup = - Types + val clippedRepeatedGroup = { + val newRepeatedGroup = Types .repeatedGroup() .as(repeatedGroup.getLogicalTypeAnnotation) - .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) - .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive, useFieldId)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive, useFieldId)) .named(repeatedGroup.getName) + if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + } Types .buildGroup(parquetMap.getRepetition) @@ -309,8 +385,12 @@ object ParquetReadSupport { * pruning. */ private def clipParquetGroup( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { + val clippedParquetFields = + clipParquetGroupFields(parquetRecord, structType, caseSensitive, useFieldId) Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getLogicalTypeAnnotation) @@ -324,23 +404,29 @@ object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. 
*/ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { - val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - if (caseSensitive) { - val caseSensitiveParquetFieldMap = + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean, + useFieldId: Boolean): Seq[Type] = { + val toParquet = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, useFieldId = useFieldId) + lazy val caseSensitiveParquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - structType.map { f => - caseSensitiveParquetFieldMap + lazy val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + lazy val idToParquetFieldMap = + parquetRecord.getFields.asScala.filter(_.getId != null).groupBy(f => f.getId.intValue()) + + def matchCaseSensitiveField(f: StructField): Type = { + caseSensitiveParquetFieldMap .get(f.name) - .map(clipParquetType(_, f.dataType, caseSensitive)) + .map(clipParquetType(_, f.dataType, caseSensitive, useFieldId)) .getOrElse(toParquet.convertField(f)) - } - } else { + } + + def matchCaseInsensitiveField(f: StructField): Type = { // Do case-insensitive resolution only if in case-insensitive mode - val caseInsensitiveParquetFieldMap = - parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) - structType.map { f => - caseInsensitiveParquetFieldMap + caseInsensitiveParquetFieldMap .get(f.name.toLowerCase(Locale.ROOT)) .map { parquetTypes => if (parquetTypes.size > 1) { @@ -349,9 +435,39 @@ object ParquetReadSupport { throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError( f.name, parquetTypesString) } else { - clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) } }.getOrElse(toParquet.convertField(f)) + } + + def matchIdField(f: StructField): Type = { + val fieldId = ParquetUtils.getFieldId(f) + idToParquetFieldMap + .get(fieldId) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError( + fieldId, parquetTypesString) + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) + } + }.getOrElse { + // When there is no ID match, we use a fake name to avoid a name match by accident + // We need this name to be unique as well, otherwise there will be type conflicts + toParquet.convertField(f.copy(name = generateFakeColumnName)) + } + } + + val shouldMatchById = useFieldId && ParquetUtils.hasFieldIds(structType) + structType.map { f => + if (shouldMatchById && ParquetUtils.hasFieldId(f)) { + matchIdField(f) + } else if (caseSensitive) { + matchCaseSensitiveField(f) + } else { + matchCaseInsensitiveField(f) } } } @@ -410,4 +526,13 @@ object ParquetReadSupport { expand(schema).asInstanceOf[StructType] } + + /** + * Whether the parquet schema contains any field IDs. + */ + def containsFieldIds(schema: Type): Boolean = schema match { + case p: PrimitiveType => p.getId != null + // We don't require all fields to have IDs, so we use `exists` here. 
+ case g: GroupType => g.getId != null || g.getFields.asScala.exists(containsFieldIds) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index b12898360dcf4..63ad5ed6db82e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -203,16 +203,38 @@ private[parquet] class ParquetRowConverter( private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false // to prevent throwing IllegalArgumentException when searching catalyst type's field index - val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) { - catalystType.fieldNames.zipWithIndex.toMap + def nameToIndex: Map[String, Int] = catalystType.fieldNames.zipWithIndex.toMap + + val catalystFieldIdxByName = if (SQLConf.get.caseSensitiveAnalysis) { + nameToIndex } else { - CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap) + CaseInsensitiveMap(nameToIndex) } + + // (SPARK-38094) parquet field ids, if exist, should be prioritized for matching + val catalystFieldIdxByFieldId = + if (SQLConf.get.parquetFieldIdReadEnabled && ParquetUtils.hasFieldIds(catalystType)) { + catalystType.fields + .zipWithIndex + .filter { case (f, _) => ParquetUtils.hasFieldId(f) } + .map { case (f, idx) => (ParquetUtils.getFieldId(f), idx) } + .toMap + } else { + Map.empty[Int, Int] + } + parquetType.getFields.asScala.map { parquetField => - val fieldIndex = catalystFieldNameToIndex(parquetField.getName) - val catalystField = catalystType(fieldIndex) + val catalystFieldIndex = Option(parquetField.getId).flatMap { fieldId => + // field has id, try to match by id first before falling back to match by name + catalystFieldIdxByFieldId.get(fieldId.intValue()) + }.getOrElse { + // field doesn't have id, just match by name + catalystFieldIdxByName(parquetField.getName) + } + val catalystField = catalystType(catalystFieldIndex) // Converted field value should be set to the `fieldIndex`-th cell of `currentRow` - newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex)) + newConverter(parquetField, + catalystField.dataType, new RowUpdater(currentRow, catalystFieldIndex)) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index cb5d646f85e9e..34a4eb8c002d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -434,20 +434,25 @@ class ParquetToSparkSchemaConverter( * When set to false, use standard format defined in parquet-format spec. This argument only * affects Parquet write path. * @param outputTimestampType which parquet timestamp type to use when writing. + * @param useFieldId whether we should include write field id to Parquet schema. Set this to false + * via `spark.sql.parquet.fieldId.write.enabled = false` to disable writing field ids. 
*/ class SparkToParquetSchemaConverter( writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = - SQLConf.ParquetOutputTimestampType.INT96) { + SQLConf.ParquetOutputTimestampType.INT96, + useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get) { def this(conf: SQLConf) = this( writeLegacyParquetFormat = conf.writeLegacyParquetFormat, - outputTimestampType = conf.parquetOutputTimestampType) + outputTimestampType = conf.parquetOutputTimestampType, + useFieldId = conf.parquetFieldIdWriteEnabled) def this(conf: Configuration) = this( writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( - conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)), + useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. @@ -463,7 +468,12 @@ class SparkToParquetSchemaConverter( * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. */ def convertField(field: StructField): Type = { - convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + val converted = convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + if (useFieldId && ParquetUtils.hasFieldId(field)) { + converted.withId(ParquetUtils.getFieldId(field)) + } else { + converted + } } private def convertField(field: StructField, repetition: Type.Repetition): Type = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 63c529e3542f2..2c565c8890e70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} object ParquetUtils { def inferSchema( @@ -145,6 +145,48 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + /** + * A StructField metadata key used to set the field id of a column in the Parquet schema. + */ + val FIELD_ID_METADATA_KEY = "parquet.field.id" + + /** + * Whether there exists a field in the schema, whether inner or leaf, has the parquet field + * ID metadata. 
+ */ + def hasFieldIds(schema: StructType): Boolean = { + def recursiveCheck(schema: DataType): Boolean = { + schema match { + case st: StructType => + st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType)) + + case at: ArrayType => recursiveCheck(at.elementType) + + case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType) + + case _ => + // No need to really check primitive types, just to terminate the recursion + false + } + } + if (schema.isEmpty) false else recursiveCheck(schema) + } + + def hasFieldId(field: StructField): Boolean = + field.metadata.contains(FIELD_ID_METADATA_KEY) + + def getFieldId(field: StructField): Int = { + require(hasFieldId(field), + s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) + try { + Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) + } catch { + case _: ArithmeticException | _: ClassCastException => + throw new IllegalArgumentException( + s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer") + } + } + /** * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala index 0316d91f40732..d84acedb962e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala @@ -81,6 +81,10 @@ case class ParquetWrite( conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sqlConf.parquetOutputTimestampType.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sqlConf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala new file mode 100644 index 0000000000000..ff0bb2f92d208 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StringType, StructType} + +class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkSession { + + private def withId(id: Int): Metadata = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + test("Parquet reads infer fields using field ids correctly") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", StringType, true, withId(0)) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixed = + new StructType() + .add("name", StringType, true) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixedHalfMatched = + new StructType() + .add("unmatched", StringType, true) + .add("b", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("random", IntegerType, true, withId(1)) + .add("name", StringType, true, withId(0)) + + val readData = Seq(Row("text", 100), Row("more", 200)) + val readDataHalfMatched = Seq(Row(null, 100), Row(null, 200)) + val writeData = Seq(Row(100, "text"), Row(200, "more")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + // read with schema + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("b < 50"), Seq.empty) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("a >= 'oh'"), Row("text", 100) :: Nil) + // read with mixed field-id/name schema + checkAnswer(spark.read.schema(readSchemaMixed).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchemaMixedHalfMatched) + .parquet(dir.getCanonicalPath), readDataHalfMatched) + + // schema inference should pull into the schema with ids + val reader = spark.read.parquet(dir.getCanonicalPath) + assert(reader.schema == writeSchema) + checkAnswer(reader.where("name >= 'oh'"), Row(100, "text") :: Nil) + } + } + } + + test("absence of field ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("b", StringType, true, withId(2)) + .add("c", IntegerType, true, withId(3)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(3)) + .add("randomName", StringType, true) + + val writeData = Seq(Row(100, "text"), Row(200, "more")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), + // 3 different cases for the 3 columns to read: + // - a: ID 1 is not found, but there is column with name `a`, still return null + // - b: ID 2 is not found, return null + // - c: ID 3 is found, read it + Row(null, null, 100) :: Row(null, null, 200) :: Nil) + } + } + } + + test("multiple id matches") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(1)) 
+ + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Found duplicate field(s)")) + } + } + } + + test("read parquet file without ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true) + .add("rand1", StringType, true) + .add("rand2", StringType, true) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + Seq(readSchema, readSchema.add("b", StringType, true)).foreach { schema => + val cause = intercept[SparkException] { + spark.read.schema(schema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) + val expectedValues = (1 to schema.length).map(_ => null) + withSQLConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key -> "true") { + checkAnswer( + spark.read.schema(schema).parquet(dir.getCanonicalPath), + Row(expectedValues: _*) :: Row(expectedValues: _*) :: Nil) + } + } + } + } + } + + test("global read/write flag should work correctly") { + withTempDir { dir => + val readSchema = + new StructType() + .add("some", IntegerType, true, withId(1)) + .add("other", StringType, true, withId(2)) + .add("name", StringType, true, withId(3)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(3)) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + val expectedResult = Seq(Row(null, null, null), Row(null, null, null)) + + withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "false", + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") { + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + // no field id found exception + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) + } + } + + withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "true", + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "false") { + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + // ids are there, but we don't use id for matching, so no results would be returned + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), expectedResult) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala new file mode 100644 index 0000000000000..b3babdd3a0cff --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -0,0 +1,528 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { + + private val FAKE_COLUMN_NAME = "_fake_name_" + private val UUID_REGEX = + "[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}".r + + private def withId(id: Int) = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String, + caseSensitive: Boolean = true, + useFieldId: Boolean = true): Unit = { + test(s"Clipping with field id - $testName") { + val fileSchema = MessageTypeParser.parseMessageType(parquetSchema) + val actual = ParquetReadSupport.clipParquetSchema( + fileSchema, + catalystSchema, + caseSensitive = caseSensitive, + useFieldId = useFieldId) + + // each fake name should be uniquely generated + val fakeColumnNames = actual.getPaths.asScala.flatten.filter(_.startsWith(FAKE_COLUMN_NAME)) + assert( + fakeColumnNames.distinct == fakeColumnNames, "Should generate unique fake column names") + + // replace the random part of all fake names with a fixed id generator + val ids1 = (1 to 100).iterator + val actualNormalized = MessageTypeParser.parseMessageType( + UUID_REGEX.replaceAllIn(actual.toString, _ => ids1.next().toString) + ) + val ids2 = (1 to 100).iterator + val expectedNormalized = MessageTypeParser.parseMessageType( + FAKE_COLUMN_NAME.r.replaceAllIn(expectedSchema, _ => s"$FAKE_COLUMN_NAME${ids2.next()}") + ) + + try { + expectedNormalized.checkContains(actualNormalized) + actualNormalized.checkContains(expectedNormalized) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expectedSchema + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + checkEqual(actualNormalized, expectedNormalized) + // might be redundant but just to have some free tests for the utils + assert(ParquetReadSupport.containsFieldIds(fileSchema)) + assert(ParquetUtils.hasFieldIds(catalystSchema)) + } + } + + private def testSqlToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String): Unit = { + val converter = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, + outputTimestampType = SQLConf.ParquetOutputTimestampType.INT96, + useFieldId = true) + + test(s"sql => parquet: $testName") { + val actual = 
converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + checkEqual(actual, expected) + } + } + + private def checkEqual(actual: MessageType, expected: MessageType): Unit = { + actual.checkContains(expected) + expected.checkContains(actual) + assert(actual.toString == expected.toString, + s""" + |Schema mismatch. + |Expected schema: + |${expected.toString} + |Actual schema: + |${actual.toString} + """.stripMargin + ) + } + + test("check hasFieldIds for schema") { + val simpleSchemaMissingId = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true) + + assert(ParquetUtils.hasFieldIds(simpleSchemaMissingId)) + + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(8)) + + assert(ParquetUtils.hasFieldIds(f01ElementType)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + assert(ParquetUtils.hasFieldIds(f0Type)) + + assert(ParquetUtils.hasFieldIds( + new StructType().add("f0", f0Type, nullable = false, withId(1)))) + + assert(!ParquetUtils.hasFieldIds(new StructType().add("f0", IntegerType, nullable = true))) + assert(!ParquetUtils.hasFieldIds(new StructType())); + } + + test("check getFieldId for schema") { + val schema = new StructType() + .add("overflowId", DoubleType, nullable = true, + new MetadataBuilder() + .putLong(ParquetUtils.FIELD_ID_METADATA_KEY, 12345678987654321L).build()) + .add("stringId", StringType, nullable = true, + new MetadataBuilder() + .putString(ParquetUtils.FIELD_ID_METADATA_KEY, "lol").build()) + .add("negativeId", LongType, nullable = true, withId(-20)) + .add("noId", LongType, nullable = true) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("noId")).get._2) + }.getMessage.contains("doesn't exist")) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("overflowId")).get._2) + }.getMessage.contains("must be a 32-bit integer")) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("stringId")).get._2) + }.getMessage.contains("must be a 32-bit integer")) + + // negative id allowed + assert(ParquetUtils.getFieldId(schema.findNestedField(Seq("negativeId")).get._2) == -20) + } + + test("check containsFieldIds for parquet schema") { + + // empty Parquet schema fails too + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 = 1 { + | optional int32 f00; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00 = 1; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | 
repeated group list { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list = 1 { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + } + + test("ID in Parquet Types is read as null when not set") { + val parquetSchemaString = + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin + + val parquetSchema = MessageTypeParser.parseMessageType(parquetSchemaString) + val f0 = parquetSchema.getFields().get(0) + assert(f0.getId() == null) + assert(f0.asGroupType().getFields.get(0).getId == null) + } + + testSqlToParquet( + "standard array", + sqlSchema = { + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("f0", f0Type, nullable = false, withId(1)) + }, + parquetSchema = + """message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f010 = 7; + | optional int64 f012 = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 f01 = 3; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add( + "g00", IntegerType, nullable = true, withId(2)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(4)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 $FAKE_COLUMN_NAME = 4; + |} + """.stripMargin) + + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional int32 f010 = 7; + | optional double f011 = 8; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("g011", DoubleType, nullable = true, withId(8)) + .add("g012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("g00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("g01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("g0", f0Type, nullable = false, withId(1)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f011 = 8; + | optional int64 $FAKE_COLUMN_NAME = 9; + | } + | } + | } + | } + |} + 
""".stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int32 value_f0 = 4; + | required int64 value_f1 = 6; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_g1", LongType, nullable = false, withId(6)) + .add("value_g2", DoubleType, nullable = false, withId(7)) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("g0", f0Type, nullable = false, withId(3)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int64 value_f1 = 6; + | required double $FAKE_COLUMN_NAME = 7; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "won't match field id if structure is different", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + // parquet has id 3, but won't use because structure is different + .add("g01", IntegerType, nullable = true, withId(3)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + }, + + // note that f1 is not picked up, even though it's Id is 3 + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 $FAKE_COLUMN_NAME = 3; + | } + |} + """.stripMargin) + + testSchemaClipping( + "Complex type with multiple mismatches should work", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(999)) + .add("g1", IntegerType, nullable = true, withId(3)) + .add("g2", IntegerType, nullable = true, withId(888)) + }, + + expectedSchema = + s"""message spark_schema { + | required group $FAKE_COLUMN_NAME = 999 { + | optional int32 g00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 $FAKE_COLUMN_NAME = 888; + |} + """.stripMargin) + + testSchemaClipping( + "Should allow fall-back to name matching if id not found", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + // nested f00 without id should also work + .add("f00", IntegerType, nullable = true) + + val f4Type = new StructType() + .add("g40", IntegerType, nullable = true, withId(6)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(3)) + // f2 without id should be matched using name matching + .add("f2", IntegerType, nullable = true) + // name is not matched + .add("g2", IntegerType, nullable = true) + // f4 without id will do name matching, but g40 will be matched using id + .add("f4", f4Type, nullable = true) + }, + + expectedSchema = + s"""message spark_schema { + | 
required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | optional int32 g2; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 272f12e138b68..2feea41d15656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -2257,7 +2257,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { caseSensitive: Boolean): Unit = { test(s"Clipping - $testName") { val actual = ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + caseSensitive, + useFieldId = false) try { expectedSchema.checkContains(actual) @@ -2821,7 +2824,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } assertThrows[RuntimeException] { ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + caseSensitive = false, + useFieldId = false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 7a7957c67dce1..18690844d484c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -165,9 +165,17 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { def withAllParquetReaders(code: => Unit): Unit = { // test the row-based reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withClue("Parquet-mr reader") { + code + } + } // test the vectorized reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + withClue("Vectorized reader") { + code + } + } } def withAllParquetWriters(code: => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 47a6f3617da63..fb3d38f3b7b18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -61,7 +61,13 @@ private[sql] object TestSQLContext { val overrideConfs: Map[String, String] = Map( // Fewer shuffle partitions to speed up testing. - SQLConf.SHUFFLE_PARTITIONS.key -> "5") + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + // Enable parquet read field id for tests to ensure correctness + // By default, if Spark schema doesn't contain the `parquet.field.id` metadata, + // the underlying matching mechanism should behave exactly like name matching + // which is the existing behavior. Therefore, turning this on ensures that we didn't + // introduce any regression for such mixed matching mode. 
+ SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") } private[sql] class TestSQLSessionStateBuilder(
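
Taken together, the clipping and row-converter changes above implement a simple resolution order. The following is an illustrative sketch only, not code from the patch; `SparkField` and `ParquetField` are hypothetical stand-ins used to keep the example self-contained.

```scala
// Hypothetical stand-in types; the real code operates on Spark StructField and Parquet Type.
case class SparkField(name: String, id: Option[Int])
case class ParquetField(name: String, id: Option[Int])

// Mirrors the priority used by ParquetReadSupport.clipParquetGroupFields:
//   1. if ID matching is enabled and the Spark field carries an ID, match by ID only
//      (an ID miss does NOT fall back to names; the real code substitutes a uniquely
//      named fake column so the reader returns nulls for that column);
//   2. otherwise match by name, case-sensitively or case-insensitively depending on
//      spark.sql.caseSensitive.
// Duplicate matches are ambiguous and raise an error in both modes.
def resolve(
    field: SparkField,
    parquetFields: Seq[ParquetField],
    useFieldId: Boolean,
    caseSensitive: Boolean): Option[ParquetField] = {
  if (useFieldId && field.id.isDefined) {
    parquetFields.filter(_.id == field.id) match {
      case Seq(single) => Some(single)
      case Seq() => None
      case dups => throw new RuntimeException(
        s"Found duplicate field(s) ${field.id.get}: $dups in id mapping mode")
    }
  } else if (caseSensitive) {
    parquetFields.find(_.name == field.name)
  } else {
    parquetFields.filter(_.name.equalsIgnoreCase(field.name)) match {
      case Seq(single) => Some(single)
      case Seq() => None
      case dups => throw new RuntimeException(
        s"Found duplicate field(s) ${field.name}: $dups in case insensitive mode")
    }
  }
}
```

Fields in the read schema that carry no `parquet.field.id` metadata keep the existing name-based resolution even when `spark.sql.parquet.fieldId.read.enabled` is on, which is why the test configuration in `TestSQLContext` can enable the flag globally without changing behavior for schemas that have no IDs.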