diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 76eb4311e41bf..d3d4ec3dab1bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -805,6 +805,15 @@ object QueryExecutionErrors {
       """.stripMargin.replaceAll("\n", " "))
   }
 
+  def foundDuplicateFieldInFieldIdLookupModeError(
+      requiredId: Int, matchedFields: String): Throwable = {
+    new RuntimeException(
+      s"""
+         |Found duplicate field(s) "$requiredId": $matchedFields
+         |in id mapping mode
+       """.stripMargin.replaceAll("\n", " "))
+  }
+
   def failedToMergeIncompatibleSchemasError(
       left: StructType, right: StructType, e: Throwable): Throwable = {
     new SparkException(s"Failed to merge incompatible schemas $left and $right", e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 42979a68d8578..9e8d01a78c9ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -934,6 +934,33 @@ object SQLConf {
     .intConf
     .createWithDefault(4096)
 
+  val PARQUET_FIELD_ID_WRITE_ENABLED =
+    buildConf("spark.sql.parquet.fieldId.write.enabled")
+      .doc("Field ID is a native field of the Parquet schema spec. When enabled, " +
+        "Parquet writers will propagate the field ID metadata (if present) " +
+        "from the Spark schema to the Parquet schema.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
+  val PARQUET_FIELD_ID_READ_ENABLED =
+    buildConf("spark.sql.parquet.fieldId.read.enabled")
+      .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers " +
+        "will use field IDs (if present) in the requested Spark schema to look up Parquet " +
+        "fields instead of using column names.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val IGNORE_MISSING_PARQUET_FIELD_ID =
+    buildConf("spark.sql.parquet.fieldId.read.ignoreMissing")
+      .doc("When the Parquet file doesn't have any field IDs but the " +
+        "Spark read schema is using field IDs to read, we will silently return nulls " +
+        "when this flag is enabled, or error otherwise.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec")
     .doc("Sets the compression codec used when writing ORC files.
If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -4251,6 +4278,12 @@ class SQLConf extends Serializable with Logging { def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT) + def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED) + + def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED) + + def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID) + def useV1Command: Boolean = getConf(SQLConf.LEGACY_USE_V1_COMMAND) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index aa6f9ee91656d..18876dedb951e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -119,6 +119,10 @@ class ParquetFileFormat SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sparkSession.sessionState.conf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index bdab0f7892f00..97e691ff7c66c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.time.ZoneId -import java.util.{Locale, Map => JMap} +import java.util +import java.util.{Locale, Map => JMap, UUID} import scala.collection.JavaConverters._ @@ -85,13 +86,71 @@ class ParquetReadSupport( StructType.fromString(schemaString) } + val parquetRequestedSchema = ParquetReadSupport.getRequestedSchema( + context.getFileSchema, catalystRequestedSchema, conf, enableVectorizedReader) + new ReadContext(parquetRequestedSchema, new util.HashMap[String, String]()) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. 
+   */
+  override def prepareForRead(
+      conf: Configuration,
+      keyValueMetaData: JMap[String, String],
+      fileSchema: MessageType,
+      readContext: ReadContext): RecordMaterializer[InternalRow] = {
+    val parquetRequestedSchema = readContext.getRequestedSchema
+    new ParquetRecordMaterializer(
+      parquetRequestedSchema,
+      ParquetReadSupport.expandUDT(catalystRequestedSchema),
+      new ParquetToSparkSchemaConverter(conf),
+      convertTz,
+      datetimeRebaseSpec,
+      int96RebaseSpec)
+  }
+}
+
+object ParquetReadSupport extends Logging {
+  val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema"
+
+  val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata"
+
+  def generateFakeColumnName: String = s"_fake_name_${UUID.randomUUID()}"
+
+  def getRequestedSchema(
+      parquetFileSchema: MessageType,
+      catalystRequestedSchema: StructType,
+      conf: Configuration,
+      enableVectorizedReader: Boolean): MessageType = {
     val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key,
       SQLConf.CASE_SENSITIVE.defaultValue.get)
     val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
       SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get)
-    val parquetFileSchema = context.getFileSchema
+    val useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key,
+      SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get)
+    val ignoreMissingIds = conf.getBoolean(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key,
+      SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.defaultValue.get)
+
+    if (!ignoreMissingIds &&
+        !containsFieldIds(parquetFileSchema) &&
+        ParquetUtils.hasFieldIds(catalystRequestedSchema)) {
+      throw new RuntimeException(
+        "Spark read schema expects field Ids, " +
+          "but Parquet file schema doesn't contain any field Ids.\n" +
+          "Please remove the field ids from Spark schema or ignore missing ids by " +
+          "setting `spark.sql.parquet.fieldId.read.ignoreMissing = true`\n" +
+          s"""
+             |Spark read schema:
+             |${catalystRequestedSchema.prettyJson}
+             |
+             |Parquet file schema:
+             |${parquetFileSchema.toString}
+             |""".stripMargin)
+    }
     val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema,
-      catalystRequestedSchema, caseSensitive)
+      catalystRequestedSchema, caseSensitive, useFieldId)
 
     // We pass two schema to ParquetRecordMaterializer:
     // - parquetRequestedSchema: the schema of the file data we want to read
@@ -109,6 +168,7 @@ class ParquetReadSupport(
       // in parquetRequestedSchema which are not present in the file.
       parquetClippedSchema
     }
+
     logDebug(
      s"""Going to read the following fields from the Parquet file with the following schema:
        |Parquet file schema:
@@ -120,34 +180,20 @@ class ParquetReadSupport(
        |Catalyst requested schema:
        |${catalystRequestedSchema.treeString}
      """.stripMargin)
-    new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava)
+
+    parquetRequestedSchema
   }
 
   /**
-   * Called on executor side after [[init()]], before instantiating actual Parquet record readers.
-   * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet
-   * records to Catalyst [[InternalRow]]s.
+ * Overloaded method for backward compatibility with + * `caseSensitive` default to `true` and `useFieldId` default to `false` */ - override def prepareForRead( - conf: Configuration, - keyValueMetaData: JMap[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - val parquetRequestedSchema = readContext.getRequestedSchema - new ParquetRecordMaterializer( - parquetRequestedSchema, - ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetToSparkSchemaConverter(conf), - convertTz, - datetimeRebaseSpec, - int96RebaseSpec) + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + clipParquetSchema(parquetSchema, catalystSchema, caseSensitive, useFieldId = false) } -} - -object ParquetReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" /** * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist @@ -156,9 +202,10 @@ object ParquetReadSupport { def clipParquetSchema( parquetSchema: MessageType, catalystSchema: StructType, - caseSensitive: Boolean = true): MessageType = { + caseSensitive: Boolean, + useFieldId: Boolean): MessageType = { val clippedParquetFields = clipParquetGroupFields( - parquetSchema.asGroupType(), catalystSchema, caseSensitive) + parquetSchema.asGroupType(), catalystSchema, caseSensitive, useFieldId) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -170,26 +217,36 @@ object ParquetReadSupport { } private def clipParquetType( - parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { - catalystType match { + parquetType: Type, + catalystType: DataType, + caseSensitive: Boolean, + useFieldId: Boolean): Type = { + val newParquetType = catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive, useFieldId) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + clipParquetMapType( + parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. parquetType } + + if (useFieldId && parquetType.getId != null) { + newParquetType.withId(parquetType.getId.intValue()) + } else { + newParquetType + } } /** @@ -210,7 +267,10 @@ object ParquetReadSupport { * [[StructType]]. */ private def clipParquetListType( - parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { + parquetList: GroupType, + elementType: DataType, + caseSensitive: Boolean, + useFieldId: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. 
assert(!isPrimitiveCatalystType(elementType)) @@ -218,7 +278,7 @@ object ParquetReadSupport { // list element type is just the group itself. Clip it. if (parquetList.getLogicalTypeAnnotation == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType, caseSensitive) + clipParquetType(parquetList, elementType, caseSensitive, useFieldId) } else { assert( parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation], @@ -250,19 +310,28 @@ object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive, useFieldId)) .named(parquetList.getName) } else { + val newRepeatedGroup = Types + .repeatedGroup() + .addField( + clipParquetType( + repeatedGroup.getType(0), elementType, caseSensitive, useFieldId)) + .named(repeatedGroup.getName) + + val newElementType = if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + // Otherwise, the repeated field's type is the element type with the repeated field's // repetition. Types .buildGroup(parquetList.getRepetition) .as(LogicalTypeAnnotation.listType()) - .addField( - Types - .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) - .named(repeatedGroup.getName)) + .addField(newElementType) .named(parquetList.getName) } } @@ -277,7 +346,8 @@ object ParquetReadSupport { parquetMap: GroupType, keyType: DataType, valueType: DataType, - caseSensitive: Boolean): GroupType = { + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -285,13 +355,19 @@ object ParquetReadSupport { val parquetKeyType = repeatedGroup.getType(0) val parquetValueType = repeatedGroup.getType(1) - val clippedRepeatedGroup = - Types + val clippedRepeatedGroup = { + val newRepeatedGroup = Types .repeatedGroup() .as(repeatedGroup.getLogicalTypeAnnotation) - .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) - .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive, useFieldId)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive, useFieldId)) .named(repeatedGroup.getName) + if (useFieldId && repeatedGroup.getId != null) { + newRepeatedGroup.withId(repeatedGroup.getId.intValue()) + } else { + newRepeatedGroup + } + } Types .buildGroup(parquetMap.getRepetition) @@ -309,8 +385,12 @@ object ParquetReadSupport { * pruning. */ private def clipParquetGroup( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean, + useFieldId: Boolean): GroupType = { + val clippedParquetFields = + clipParquetGroupFields(parquetRecord, structType, caseSensitive, useFieldId) Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getLogicalTypeAnnotation) @@ -324,23 +404,29 @@ object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. 
*/ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { - val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - if (caseSensitive) { - val caseSensitiveParquetFieldMap = + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean, + useFieldId: Boolean): Seq[Type] = { + val toParquet = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, useFieldId = useFieldId) + lazy val caseSensitiveParquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - structType.map { f => - caseSensitiveParquetFieldMap + lazy val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + lazy val idToParquetFieldMap = + parquetRecord.getFields.asScala.filter(_.getId != null).groupBy(f => f.getId.intValue()) + + def matchCaseSensitiveField(f: StructField): Type = { + caseSensitiveParquetFieldMap .get(f.name) - .map(clipParquetType(_, f.dataType, caseSensitive)) + .map(clipParquetType(_, f.dataType, caseSensitive, useFieldId)) .getOrElse(toParquet.convertField(f)) - } - } else { + } + + def matchCaseInsensitiveField(f: StructField): Type = { // Do case-insensitive resolution only if in case-insensitive mode - val caseInsensitiveParquetFieldMap = - parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) - structType.map { f => - caseInsensitiveParquetFieldMap + caseInsensitiveParquetFieldMap .get(f.name.toLowerCase(Locale.ROOT)) .map { parquetTypes => if (parquetTypes.size > 1) { @@ -349,9 +435,39 @@ object ParquetReadSupport { throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError( f.name, parquetTypesString) } else { - clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) } }.getOrElse(toParquet.convertField(f)) + } + + def matchIdField(f: StructField): Type = { + val fieldId = ParquetUtils.getFieldId(f) + idToParquetFieldMap + .get(fieldId) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError( + fieldId, parquetTypesString) + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive, useFieldId) + } + }.getOrElse { + // When there is no ID match, we use a fake name to avoid a name match by accident + // We need this name to be unique as well, otherwise there will be type conflicts + toParquet.convertField(f.copy(name = generateFakeColumnName)) + } + } + + val shouldMatchById = useFieldId && ParquetUtils.hasFieldIds(structType) + structType.map { f => + if (shouldMatchById && ParquetUtils.hasFieldId(f)) { + matchIdField(f) + } else if (caseSensitive) { + matchCaseSensitiveField(f) + } else { + matchCaseInsensitiveField(f) } } } @@ -410,4 +526,13 @@ object ParquetReadSupport { expand(schema).asInstanceOf[StructType] } + + /** + * Whether the parquet schema contains any field IDs. + */ + def containsFieldIds(schema: Type): Boolean = schema match { + case p: PrimitiveType => p.getId != null + // We don't require all fields to have IDs, so we use `exists` here. 
+ case g: GroupType => g.getId != null || g.getFields.asScala.exists(containsFieldIds) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index b12898360dcf4..63ad5ed6db82e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -203,16 +203,38 @@ private[parquet] class ParquetRowConverter( private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false // to prevent throwing IllegalArgumentException when searching catalyst type's field index - val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) { - catalystType.fieldNames.zipWithIndex.toMap + def nameToIndex: Map[String, Int] = catalystType.fieldNames.zipWithIndex.toMap + + val catalystFieldIdxByName = if (SQLConf.get.caseSensitiveAnalysis) { + nameToIndex } else { - CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap) + CaseInsensitiveMap(nameToIndex) } + + // (SPARK-38094) parquet field ids, if exist, should be prioritized for matching + val catalystFieldIdxByFieldId = + if (SQLConf.get.parquetFieldIdReadEnabled && ParquetUtils.hasFieldIds(catalystType)) { + catalystType.fields + .zipWithIndex + .filter { case (f, _) => ParquetUtils.hasFieldId(f) } + .map { case (f, idx) => (ParquetUtils.getFieldId(f), idx) } + .toMap + } else { + Map.empty[Int, Int] + } + parquetType.getFields.asScala.map { parquetField => - val fieldIndex = catalystFieldNameToIndex(parquetField.getName) - val catalystField = catalystType(fieldIndex) + val catalystFieldIndex = Option(parquetField.getId).flatMap { fieldId => + // field has id, try to match by id first before falling back to match by name + catalystFieldIdxByFieldId.get(fieldId.intValue()) + }.getOrElse { + // field doesn't have id, just match by name + catalystFieldIdxByName(parquetField.getName) + } + val catalystField = catalystType(catalystFieldIndex) // Converted field value should be set to the `fieldIndex`-th cell of `currentRow` - newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex)) + newConverter(parquetField, + catalystField.dataType, new RowUpdater(currentRow, catalystFieldIndex)) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index cb5d646f85e9e..34a4eb8c002d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -434,20 +434,25 @@ class ParquetToSparkSchemaConverter( * When set to false, use standard format defined in parquet-format spec. This argument only * affects Parquet write path. * @param outputTimestampType which parquet timestamp type to use when writing. + * @param useFieldId whether we should include write field id to Parquet schema. Set this to false + * via `spark.sql.parquet.fieldId.write.enabled = false` to disable writing field ids. 
*/ class SparkToParquetSchemaConverter( writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = - SQLConf.ParquetOutputTimestampType.INT96) { + SQLConf.ParquetOutputTimestampType.INT96, + useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get) { def this(conf: SQLConf) = this( writeLegacyParquetFormat = conf.writeLegacyParquetFormat, - outputTimestampType = conf.parquetOutputTimestampType) + outputTimestampType = conf.parquetOutputTimestampType, + useFieldId = conf.parquetFieldIdWriteEnabled) def this(conf: Configuration) = this( writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( - conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)), + useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. @@ -463,7 +468,12 @@ class SparkToParquetSchemaConverter( * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. */ def convertField(field: StructField): Type = { - convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + val converted = convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + if (useFieldId && ParquetUtils.hasFieldId(field)) { + converted.withId(ParquetUtils.getFieldId(field)) + } else { + converted + } } private def convertField(field: StructField, repetition: Type.Repetition): Type = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 87a0d9c860f31..812fa2224d284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} object ParquetUtils { def inferSchema( @@ -144,6 +144,48 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + /** + * A StructField metadata key used to set the field id of a column in the Parquet schema. + */ + val FIELD_ID_METADATA_KEY = "parquet.field.id" + + /** + * Whether there exists a field in the schema, whether inner or leaf, has the parquet field + * ID metadata. 
+ */ + def hasFieldIds(schema: StructType): Boolean = { + def recursiveCheck(schema: DataType): Boolean = { + schema match { + case st: StructType => + st.exists(field => hasFieldId(field) || recursiveCheck(field.dataType)) + + case at: ArrayType => recursiveCheck(at.elementType) + + case mt: MapType => recursiveCheck(mt.keyType) || recursiveCheck(mt.valueType) + + case _ => + // No need to really check primitive types, just to terminate the recursion + false + } + } + if (schema.isEmpty) false else recursiveCheck(schema) + } + + def hasFieldId(field: StructField): Boolean = + field.metadata.contains(FIELD_ID_METADATA_KEY) + + def getFieldId(field: StructField): Int = { + require(hasFieldId(field), + s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + field) + try { + Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY)) + } catch { + case _: ArithmeticException | _: ClassCastException => + throw new IllegalArgumentException( + s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer") + } + } + /** * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala index 0316d91f40732..d84acedb962e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala @@ -81,6 +81,10 @@ case class ParquetWrite( conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, sqlConf.parquetOutputTimestampType.toString) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sqlConf.parquetFieldIdWriteEnabled.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala new file mode 100644 index 0000000000000..ff0bb2f92d208 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdIOSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StringType, StructType} + +class ParquetFieldIdIOSuite extends QueryTest with ParquetTest with SharedSparkSession { + + private def withId(id: Int): Metadata = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + test("Parquet reads infer fields using field ids correctly") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", StringType, true, withId(0)) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixed = + new StructType() + .add("name", StringType, true) + .add("b", IntegerType, true, withId(1)) + + val readSchemaMixedHalfMatched = + new StructType() + .add("unmatched", StringType, true) + .add("b", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("random", IntegerType, true, withId(1)) + .add("name", StringType, true, withId(0)) + + val readData = Seq(Row("text", 100), Row("more", 200)) + val readDataHalfMatched = Seq(Row(null, 100), Row(null, 200)) + val writeData = Seq(Row(100, "text"), Row(200, "more")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + // read with schema + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("b < 50"), Seq.empty) + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath) + .where("a >= 'oh'"), Row("text", 100) :: Nil) + // read with mixed field-id/name schema + checkAnswer(spark.read.schema(readSchemaMixed).parquet(dir.getCanonicalPath), readData) + checkAnswer(spark.read.schema(readSchemaMixedHalfMatched) + .parquet(dir.getCanonicalPath), readDataHalfMatched) + + // schema inference should pull into the schema with ids + val reader = spark.read.parquet(dir.getCanonicalPath) + assert(reader.schema == writeSchema) + checkAnswer(reader.where("name >= 'oh'"), Row(100, "text") :: Nil) + } + } + } + + test("absence of field ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("b", StringType, true, withId(2)) + .add("c", IntegerType, true, withId(3)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(3)) + .add("randomName", StringType, true) + + val writeData = Seq(Row(100, "text"), Row(200, "more")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), + // 3 different cases for the 3 columns to read: + // - a: ID 1 is not found, but there is column with name `a`, still return null + // - b: ID 2 is not found, return null + // - c: ID 3 is found, read it + Row(null, null, 100) :: Row(null, null, 200) :: Nil) + } + } + } + + test("multiple id matches") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(1)) 
+ + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + + withAllParquetReaders { + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Found duplicate field(s)")) + } + } + } + + test("read parquet file without ids") { + withTempDir { dir => + val readSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true) + .add("rand1", StringType, true) + .add("rand2", StringType, true) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + Seq(readSchema, readSchema.add("b", StringType, true)).foreach { schema => + val cause = intercept[SparkException] { + spark.read.schema(schema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) + val expectedValues = (1 to schema.length).map(_ => null) + withSQLConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key -> "true") { + checkAnswer( + spark.read.schema(schema).parquet(dir.getCanonicalPath), + Row(expectedValues: _*) :: Row(expectedValues: _*) :: Nil) + } + } + } + } + } + + test("global read/write flag should work correctly") { + withTempDir { dir => + val readSchema = + new StructType() + .add("some", IntegerType, true, withId(1)) + .add("other", StringType, true, withId(2)) + .add("name", StringType, true, withId(3)) + + val writeSchema = + new StructType() + .add("a", IntegerType, true, withId(1)) + .add("rand1", StringType, true, withId(2)) + .add("rand2", StringType, true, withId(3)) + + val writeData = Seq(Row(100, "text", "txt"), Row(200, "more", "mr")) + + val expectedResult = Seq(Row(null, null, null), Row(null, null, null)) + + withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "false", + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") { + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + // no field id found exception + val cause = intercept[SparkException] { + spark.read.schema(readSchema).parquet(dir.getCanonicalPath).collect() + }.getCause + assert(cause.isInstanceOf[RuntimeException] && + cause.getMessage.contains("Parquet file schema doesn't contain any field Ids")) + } + } + + withSQLConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key -> "true", + SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "false") { + spark.createDataFrame(writeData.asJava, writeSchema) + .write.mode("overwrite").parquet(dir.getCanonicalPath) + withAllParquetReaders { + // ids are there, but we don't use id for matching, so no results would be returned + checkAnswer(spark.read.schema(readSchema).parquet(dir.getCanonicalPath), expectedResult) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala new file mode 100644 index 0000000000000..b3babdd3a0cff --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala @@ -0,0 +1,528 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class ParquetFieldIdSchemaSuite extends ParquetSchemaTest { + + private val FAKE_COLUMN_NAME = "_fake_name_" + private val UUID_REGEX = + "[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}".r + + private def withId(id: Int) = + new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build() + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String, + caseSensitive: Boolean = true, + useFieldId: Boolean = true): Unit = { + test(s"Clipping with field id - $testName") { + val fileSchema = MessageTypeParser.parseMessageType(parquetSchema) + val actual = ParquetReadSupport.clipParquetSchema( + fileSchema, + catalystSchema, + caseSensitive = caseSensitive, + useFieldId = useFieldId) + + // each fake name should be uniquely generated + val fakeColumnNames = actual.getPaths.asScala.flatten.filter(_.startsWith(FAKE_COLUMN_NAME)) + assert( + fakeColumnNames.distinct == fakeColumnNames, "Should generate unique fake column names") + + // replace the random part of all fake names with a fixed id generator + val ids1 = (1 to 100).iterator + val actualNormalized = MessageTypeParser.parseMessageType( + UUID_REGEX.replaceAllIn(actual.toString, _ => ids1.next().toString) + ) + val ids2 = (1 to 100).iterator + val expectedNormalized = MessageTypeParser.parseMessageType( + FAKE_COLUMN_NAME.r.replaceAllIn(expectedSchema, _ => s"$FAKE_COLUMN_NAME${ids2.next()}") + ) + + try { + expectedNormalized.checkContains(actualNormalized) + actualNormalized.checkContains(expectedNormalized) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expectedSchema + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + checkEqual(actualNormalized, expectedNormalized) + // might be redundant but just to have some free tests for the utils + assert(ParquetReadSupport.containsFieldIds(fileSchema)) + assert(ParquetUtils.hasFieldIds(catalystSchema)) + } + } + + private def testSqlToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String): Unit = { + val converter = new SparkToParquetSchemaConverter( + writeLegacyParquetFormat = false, + outputTimestampType = SQLConf.ParquetOutputTimestampType.INT96, + useFieldId = true) + + test(s"sql => parquet: $testName") { + val actual = 
converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + checkEqual(actual, expected) + } + } + + private def checkEqual(actual: MessageType, expected: MessageType): Unit = { + actual.checkContains(expected) + expected.checkContains(actual) + assert(actual.toString == expected.toString, + s""" + |Schema mismatch. + |Expected schema: + |${expected.toString} + |Actual schema: + |${actual.toString} + """.stripMargin + ) + } + + test("check hasFieldIds for schema") { + val simpleSchemaMissingId = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true) + + assert(ParquetUtils.hasFieldIds(simpleSchemaMissingId)) + + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(8)) + + assert(ParquetUtils.hasFieldIds(f01ElementType)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + assert(ParquetUtils.hasFieldIds(f0Type)) + + assert(ParquetUtils.hasFieldIds( + new StructType().add("f0", f0Type, nullable = false, withId(1)))) + + assert(!ParquetUtils.hasFieldIds(new StructType().add("f0", IntegerType, nullable = true))) + assert(!ParquetUtils.hasFieldIds(new StructType())); + } + + test("check getFieldId for schema") { + val schema = new StructType() + .add("overflowId", DoubleType, nullable = true, + new MetadataBuilder() + .putLong(ParquetUtils.FIELD_ID_METADATA_KEY, 12345678987654321L).build()) + .add("stringId", StringType, nullable = true, + new MetadataBuilder() + .putString(ParquetUtils.FIELD_ID_METADATA_KEY, "lol").build()) + .add("negativeId", LongType, nullable = true, withId(-20)) + .add("noId", LongType, nullable = true) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("noId")).get._2) + }.getMessage.contains("doesn't exist")) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("overflowId")).get._2) + }.getMessage.contains("must be a 32-bit integer")) + + assert(intercept[IllegalArgumentException] { + ParquetUtils.getFieldId(schema.findNestedField(Seq("stringId")).get._2) + }.getMessage.contains("must be a 32-bit integer")) + + // negative id allowed + assert(ParquetUtils.getFieldId(schema.findNestedField(Seq("negativeId")).get._2) == -20) + } + + test("check containsFieldIds for parquet schema") { + + // empty Parquet schema fails too + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 = 1 { + | optional int32 f00; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message root { + | required group f0 { + | optional int32 f00 = 1; + | optional binary f01; + | } + |} + """.stripMargin))) + + assert( + !ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | 
repeated group list { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + + assert( + ParquetReadSupport.containsFieldIds( + MessageTypeParser.parseMessageType( + """message spark_schema { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list = 1 { + | required binary element (UTF8); + | } + | } + | } + |} + """.stripMargin))) + } + + test("ID in Parquet Types is read as null when not set") { + val parquetSchemaString = + """message root { + | required group f0 { + | optional int32 f00; + | } + |} + """.stripMargin + + val parquetSchema = MessageTypeParser.parseMessageType(parquetSchemaString) + val f0 = parquetSchema.getFields().get(0) + assert(f0.getId() == null) + assert(f0.asGroupType().getFields.get(0).getId == null) + } + + testSqlToParquet( + "standard array", + sqlSchema = { + val f01ElementType = new StructType() + .add("f010", DoubleType, nullable = true, withId(7)) + .add("f012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("f0", f0Type, nullable = false, withId(1)) + }, + parquetSchema = + """message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f010 = 7; + | optional int64 f012 = 9; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 f01 = 3; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add( + "g00", IntegerType, nullable = true, withId(2)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(4)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 $FAKE_COLUMN_NAME = 4; + |} + """.stripMargin) + + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional int32 f010 = 7; + | optional double f011 = 8; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("g011", DoubleType, nullable = true, withId(8)) + .add("g012", LongType, nullable = true, withId(9)) + + val f0Type = new StructType() + .add("g00", ArrayType(StringType, containsNull = false), nullable = true, withId(2)) + .add("g01", ArrayType(f01ElementType, containsNull = false), nullable = true, withId(5)) + + new StructType().add("g0", f0Type, nullable = false, withId(1)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional group f00 (LIST) = 2 { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) = 5 { + | repeated group list { + | required group element { + | optional double f011 = 8; + | optional int64 $FAKE_COLUMN_NAME = 9; + | } + | } + | } + | } + |} + 
""".stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int32 value_f0 = 4; + | required int64 value_f1 = 6; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_g1", LongType, nullable = false, withId(6)) + .add("value_g2", DoubleType, nullable = false, withId(7)) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("g0", f0Type, nullable = false, withId(3)) + }, + + expectedSchema = + s"""message spark_schema { + | required group f0 (MAP) = 3 { + | repeated group key_value = 1 { + | required group key = 2 { + | required int64 value_f1 = 6; + | required double $FAKE_COLUMN_NAME = 7; + | } + | required int32 value = 5; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "won't match field id if structure is different", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + // parquet has id 3, but won't use because structure is different + .add("g01", IntegerType, nullable = true, withId(3)) + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + }, + + // note that f1 is not picked up, even though it's Id is 3 + expectedSchema = + s"""message spark_schema { + | required group f0 = 1 { + | optional int32 f00 = 2; + | optional int32 $FAKE_COLUMN_NAME = 3; + | } + |} + """.stripMargin) + + testSchemaClipping( + "Complex type with multiple mismatches should work", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + .add("g00", IntegerType, nullable = true, withId(2)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(999)) + .add("g1", IntegerType, nullable = true, withId(3)) + .add("g2", IntegerType, nullable = true, withId(888)) + }, + + expectedSchema = + s"""message spark_schema { + | required group $FAKE_COLUMN_NAME = 999 { + | optional int32 g00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 $FAKE_COLUMN_NAME = 888; + |} + """.stripMargin) + + testSchemaClipping( + "Should allow fall-back to name matching if id not found", + + parquetSchema = + """message root { + | required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType() + // nested f00 without id should also work + .add("f00", IntegerType, nullable = true) + + val f4Type = new StructType() + .add("g40", IntegerType, nullable = true, withId(6)) + + new StructType() + .add("g0", f0Type, nullable = false, withId(1)) + .add("g1", IntegerType, nullable = true, withId(3)) + // f2 without id should be matched using name matching + .add("f2", IntegerType, nullable = true) + // name is not matched + .add("g2", IntegerType, nullable = true) + // f4 without id will do name matching, but g40 will be matched using id + .add("f4", f4Type, nullable = true) + }, + + expectedSchema = + s"""message spark_schema { + | 
required group f0 = 1 { + | optional int32 f00 = 2; + | } + | optional int32 f1 = 3; + | optional int32 f2 = 4; + | optional int32 g2; + | required group f4 = 5 { + | optional int32 f40 = 6; + | } + |} + """.stripMargin) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 272f12e138b68..2feea41d15656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -2257,7 +2257,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { caseSensitive: Boolean): Unit = { test(s"Clipping - $testName") { val actual = ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + caseSensitive, + useFieldId = false) try { expectedSchema.checkContains(actual) @@ -2821,7 +2824,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } assertThrows[RuntimeException] { ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + MessageTypeParser.parseMessageType(parquetSchema), + catalystSchema, + caseSensitive = false, + useFieldId = false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 7a7957c67dce1..18690844d484c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -165,9 +165,17 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { def withAllParquetReaders(code: => Unit): Unit = { // test the row-based reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withClue("Parquet-mr reader") { + code + } + } // test the vectorized reader - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true")(code) + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + withClue("Vectorized reader") { + code + } + } } def withAllParquetWriters(code: => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 47a6f3617da63..fb3d38f3b7b18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -61,7 +61,13 @@ private[sql] object TestSQLContext { val overrideConfs: Map[String, String] = Map( // Fewer shuffle partitions to speed up testing. - SQLConf.SHUFFLE_PARTITIONS.key -> "5") + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + // Enable parquet read field id for tests to ensure correctness + // By default, if Spark schema doesn't contain the `parquet.field.id` metadata, + // the underlying matching mechanism should behave exactly like name matching + // which is the existing behavior. Therefore, turning this on ensures that we didn't + // introduce any regression for such mixed matching mode. 
+ SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key -> "true") } private[sql] class TestSQLSessionStateBuilder(
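
Illustrative usage (editor's sketch, not part of the patch): the Scala snippet below shows how the field-ID support added above is intended to be exercised end to end. It assumes an active `SparkSession` named `spark`; the column names and output path are made up for the example, while the metadata key `parquet.field.id` and the `spark.sql.parquet.fieldId.*` configs come directly from the changes in this diff.

    import scala.collection.JavaConverters._

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructType}

    // Attach a Parquet field ID to a column through StructField metadata
    // (the same helper pattern the new test suites use).
    def withId(id: Int) =
      new MetadataBuilder().putLong("parquet.field.id", id).build()

    val writeSchema = new StructType()
      .add("name", StringType, nullable = true, withId(0))
      .add("count", IntegerType, nullable = true, withId(1))

    // Field IDs are written by default (spark.sql.parquet.fieldId.write.enabled = true).
    spark.createDataFrame(Seq(Row("text", 100), Row("more", 200)).asJava, writeSchema)
      .write.mode("overwrite").parquet("/tmp/parquet-field-id-demo")

    // Opt in to field-ID based column resolution on the read path.
    spark.conf.set("spark.sql.parquet.fieldId.read.enabled", "true")

    // The read schema uses different column names but matching IDs,
    // so columns are resolved by ID rather than by name.
    val readSchema = new StructType()
      .add("renamed", StringType, nullable = true, withId(0))
      .add("renamedCount", IntegerType, nullable = true, withId(1))

    spark.read.schema(readSchema).parquet("/tmp/parquet-field-id-demo").show()

If the file had been written without field IDs, this read would fail with the "Parquet file schema doesn't contain any field Ids" error unless `spark.sql.parquet.fieldId.read.ignoreMissing` is set to `true`, in which case the ID-matched columns come back as nulls.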