From dc8523b1965ba074ef2644dc0627a1e0bf736a61 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Wed, 7 Aug 2024 21:34:34 -0700 Subject: [PATCH 1/8] [SEDONA-637] Refactor multiple spark version build and package --- pom.xml | 52 +- spark/pom.xml | 4 + spark/spark-3.1/.gitignore | 12 + spark/spark-3.1/pom.xml | 145 ++++ ...pache.spark.sql.sources.DataSourceRegister | 2 + .../parquet/GeoDataSourceUtils.scala | 147 ++++ .../parquet/GeoDateTimeUtils.scala | 43 + .../parquet/GeoParquetFileFormat.scala | 437 ++++++++++ .../parquet/GeoParquetFilters.scala | 678 ++++++++++++++++ .../parquet/GeoParquetReadSupport.scala | 418 ++++++++++ .../GeoParquetRecordMaterializer.scala | 69 ++ .../parquet/GeoParquetRowConverter.scala | 745 +++++++++++++++++ .../parquet/GeoParquetSchemaConverter.scala | 601 ++++++++++++++ .../datasources/parquet/GeoParquetUtils.scala | 127 +++ .../parquet/GeoParquetWriteSupport.scala | 628 +++++++++++++++ .../parquet/GeoSchemaMergeUtils.scala | 107 +++ .../GeoParquetMetadataDataSource.scala | 65 ++ ...arquetMetadataPartitionReaderFactory.scala | 118 +++ .../metadata/GeoParquetMetadataScan.scala | 69 ++ .../GeoParquetMetadataScanBuilder.scala | 84 ++ .../metadata/GeoParquetMetadataTable.scala | 70 ++ .../src/test/resources/log4j2.properties | 31 + .../sedona/sql/GeoParquetMetadataTests.scala | 152 ++++ ...GeoParquetSpatialFilterPushDownSuite.scala | 347 ++++++++ .../org/apache/sedona/sql/TestBaseScala.scala | 57 ++ .../apache/sedona/sql/geoparquetIOTests.scala | 748 ++++++++++++++++++ spark/spark-3.2/.gitignore | 12 + spark/spark-3.2/pom.xml | 145 ++++ ...pache.spark.sql.sources.DataSourceRegister | 2 + .../parquet/GeoDataSourceUtils.scala | 147 ++++ .../parquet/GeoDateTimeUtils.scala | 43 + .../parquet/GeoParquetFileFormat.scala | 437 ++++++++++ .../parquet/GeoParquetFilters.scala | 678 ++++++++++++++++ .../parquet/GeoParquetReadSupport.scala | 418 ++++++++++ .../GeoParquetRecordMaterializer.scala | 69 ++ .../parquet/GeoParquetRowConverter.scala | 745 +++++++++++++++++ .../parquet/GeoParquetSchemaConverter.scala | 601 ++++++++++++++ .../datasources/parquet/GeoParquetUtils.scala | 127 +++ .../parquet/GeoParquetWriteSupport.scala | 628 +++++++++++++++ .../parquet/GeoSchemaMergeUtils.scala | 107 +++ .../GeoParquetMetadataDataSource.scala | 65 ++ ...arquetMetadataPartitionReaderFactory.scala | 118 +++ .../metadata/GeoParquetMetadataScan.scala | 69 ++ .../GeoParquetMetadataScanBuilder.scala | 84 ++ .../metadata/GeoParquetMetadataTable.scala | 70 ++ .../src/test/resources/log4j2.properties | 31 + .../sedona/sql/GeoParquetMetadataTests.scala | 152 ++++ ...GeoParquetSpatialFilterPushDownSuite.scala | 347 ++++++++ .../org/apache/sedona/sql/TestBaseScala.scala | 57 ++ .../apache/sedona/sql/geoparquetIOTests.scala | 748 ++++++++++++++++++ spark/spark-3.3/.gitignore | 12 + spark/spark-3.3/pom.xml | 145 ++++ ...pache.spark.sql.sources.DataSourceRegister | 2 + .../parquet/GeoDataSourceUtils.scala | 147 ++++ .../parquet/GeoDateTimeUtils.scala | 43 + .../parquet/GeoParquetFileFormat.scala | 437 ++++++++++ .../parquet/GeoParquetFilters.scala | 678 ++++++++++++++++ .../parquet/GeoParquetReadSupport.scala | 418 ++++++++++ .../GeoParquetRecordMaterializer.scala | 69 ++ .../parquet/GeoParquetRowConverter.scala | 745 +++++++++++++++++ .../parquet/GeoParquetSchemaConverter.scala | 601 ++++++++++++++ .../datasources/parquet/GeoParquetUtils.scala | 127 +++ .../parquet/GeoParquetWriteSupport.scala | 628 +++++++++++++++ .../parquet/GeoSchemaMergeUtils.scala | 107 +++ .../GeoParquetMetadataDataSource.scala 
| 65 ++ ...arquetMetadataPartitionReaderFactory.scala | 118 +++ .../metadata/GeoParquetMetadataScan.scala | 69 ++ .../GeoParquetMetadataScanBuilder.scala | 84 ++ .../metadata/GeoParquetMetadataTable.scala | 70 ++ .../src/test/resources/log4j2.properties | 31 + .../sedona/sql/GeoParquetMetadataTests.scala | 152 ++++ ...GeoParquetSpatialFilterPushDownSuite.scala | 347 ++++++++ .../org/apache/sedona/sql/TestBaseScala.scala | 57 ++ .../apache/sedona/sql/geoparquetIOTests.scala | 748 ++++++++++++++++++ 74 files changed, 17754 insertions(+), 2 deletions(-) create mode 100644 spark/spark-3.1/.gitignore create mode 100644 spark/spark-3.1/pom.xml create mode 100644 spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala create mode 100644 spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala create mode 100644 spark/spark-3.1/src/test/resources/log4j2.properties create mode 100644 spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala create mode 100644 spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala create mode 100644 spark/spark-3.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala create mode 100644 spark/spark-3.1/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala create mode 100644 spark/spark-3.2/.gitignore create mode 100644 spark/spark-3.2/pom.xml create mode 100644 
spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala create mode 100644 spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala create mode 100644 spark/spark-3.2/src/test/resources/log4j2.properties create mode 100644 spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala create mode 100644 spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala create mode 100644 spark/spark-3.2/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala create mode 100644 spark/spark-3.2/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala create mode 100644 spark/spark-3.3/.gitignore create mode 100644 spark/spark-3.3/pom.xml create mode 100644 spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala create mode 100644 
spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala create mode 100644 spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala create mode 100644 spark/spark-3.3/src/test/resources/log4j2.properties create mode 100644 spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala create mode 100644 spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala create mode 100644 spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala create mode 100644 spark/spark-3.3/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala diff --git a/pom.xml b/pom.xml index aa8b2f0f7a..9fbf5db5ec 100644 --- a/pom.xml +++ b/pom.xml @@ -83,7 +83,7 @@ Setting a default value helps IDE:s that can't make sense of profiles. 
--> 2.12 3.3.0 - 3.0 + 3.3 2.17.2 1.19.0 @@ -684,11 +684,59 @@ true - 3.3.0 + 3.0.0 3.0 2.17.2 + + + sedona-spark-3.1 + + + spark + 3.1 + + true + + + 3.1.0 + 3.1 + 2.17.2 + + + + + sedona-spark-3.2 + + + spark + 3.2 + + true + + + 3.2.0 + 3.2 + 2.17.2 + + + + + sedona-spark-3.3 + + + spark + 3.3 + + true + + + 3.3.0 + 3.3 + 2.17.2 + + sedona-spark-3.4 diff --git a/spark/pom.xml b/spark/pom.xml index ac7750a2c5..482e027ad1 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -48,6 +48,10 @@ enable-all-submodules + spark-3.0 + spark-3.1 + spark-3.2 + spark-3.3 spark-3.4 spark-3.5 diff --git a/spark/spark-3.1/.gitignore b/spark/spark-3.1/.gitignore new file mode 100644 index 0000000000..1cc6c4a1f6 --- /dev/null +++ b/spark/spark-3.1/.gitignore @@ -0,0 +1,12 @@ +/target/ +/.settings/ +/.classpath +/.project +/dependency-reduced-pom.xml +/doc/ +/.idea/ +*.iml +/latest/ +/spark-warehouse/ +/metastore_db/ +*.log diff --git a/spark/spark-3.1/pom.xml b/spark/spark-3.1/pom.xml new file mode 100644 index 0000000000..56f54ae1b2 --- /dev/null +++ b/spark/spark-3.1/pom.xml @@ -0,0 +1,145 @@ + + + + 4.0.0 + + org.apache.sedona + sedona-spark-parent-${spark.compat.version}_${scala.compat.version} + 1.6.1-SNAPSHOT + ../pom.xml + + sedona-spark-3.1_${scala.compat.version} + + ${project.groupId}:${project.artifactId} + A cluster computing system for processing large-scale spatial data: SQL API for Spark 3.1. + http://sedona.apache.org/ + jar + + + false + + + + + org.apache.sedona + sedona-common + ${project.version} + + + com.fasterxml.jackson.core + * + + + + + org.apache.sedona + sedona-spark-common-${spark.compat.version}_${scala.compat.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.compat.version} + + + org.apache.spark + spark-sql_${scala.compat.version} + + + org.apache.hadoop + hadoop-client + + + org.apache.logging.log4j + log4j-1.2-api + + + org.geotools + gt-main + + + org.geotools + gt-referencing + + + org.geotools + gt-epsg-hsql + + + org.geotools + gt-geotiff + + + org.geotools + gt-coverage + + + org.geotools + gt-arcgrid + + + org.locationtech.jts + jts-core + + + org.wololo + jts2geojson + + + com.fasterxml.jackson.core + * + + + + + org.scala-lang + scala-library + + + org.scala-lang.modules + scala-collection-compat_${scala.compat.version} + + + org.scalatest + scalatest_${scala.compat.version} + + + org.mockito + mockito-inline + + + + src/main/scala + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + + diff --git a/spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000..e5f994e203 --- /dev/null +++ b/spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,2 @@ +org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala new file mode 100644 index 0000000000..4348325570 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala @@ -0,0 
+1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.util.RebaseDateTime +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.util.Utils + +import scala.util.Try + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoDataSourceUtils { + + val PARQUET_REBASE_MODE_IN_READ = firstAvailableConf( + "spark.sql.parquet.datetimeRebaseModeInRead", + "spark.sql.legacy.parquet.datetimeRebaseModeInRead") + val PARQUET_REBASE_MODE_IN_WRITE = firstAvailableConf( + "spark.sql.parquet.datetimeRebaseModeInWrite", + "spark.sql.legacy.parquet.datetimeRebaseModeInWrite") + val PARQUET_INT96_REBASE_MODE_IN_READ = firstAvailableConf( + "spark.sql.parquet.int96RebaseModeInRead", + "spark.sql.legacy.parquet.int96RebaseModeInRead", + "spark.sql.legacy.parquet.datetimeRebaseModeInRead") + val PARQUET_INT96_REBASE_MODE_IN_WRITE = firstAvailableConf( + "spark.sql.parquet.int96RebaseModeInWrite", + "spark.sql.legacy.parquet.int96RebaseModeInWrite", + "spark.sql.legacy.parquet.datetimeRebaseModeInWrite") + + private def firstAvailableConf(confs: String*): String = { + confs.find(c => Try(SQLConf.get.getConfString(c)).isSuccess).get + } + + def datetimeRebaseMode( + lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) + .map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + } + .getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + + def int96RebaseMode( + lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. 
+ Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) + .map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + } + .getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + + def creteDateRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + days: Int => + if (days < RebaseDateTime.lastSwitchJulianDay) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteDateRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + days: Int => + if (days < RebaseDateTime.lastSwitchGregorianDay) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteTimestampRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + micros: Long => + if (micros < RebaseDateTime.lastSwitchJulianTs) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } + + def creteTimestampRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + micros: Long => + if (micros < RebaseDateTime.lastSwitchGregorianTs) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala new file mode 100644 index 0000000000..bf3c2a19a9 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoDateTimeUtils { + + /** + * Converts the timestamp to milliseconds since epoch. In Spark timestamp values have + * microseconds precision, so this conversion is lossy. + */ + def microsToMillis(micros: Long): Long = { + // When the timestamp is negative i.e before 1970, we need to adjust the milliseconds portion. + // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision. + // In millis precision the above needs to be represented as (-157700927877). + Math.floorDiv(micros, MICROS_PER_MILLIS) + } + + /** + * Converts milliseconds since the epoch to microseconds. + */ + def millisToMicros(millis: Long): Long = { + Math.multiplyExact(millis, MICROS_PER_MILLIS) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala new file mode 100644 index 0000000000..702c6f31fb --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.codec.CodecConfig +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.readParquetFootersInParallel +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +import java.net.URI +import scala.collection.JavaConverters._ +import scala.util.Failure +import scala.util.Try + +class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter]) + extends ParquetFileFormat + with GeoParquetFileFormatBase + with FileFormat + with DataSourceRegister + with Logging + with Serializable { + + def this() = this(None) + + override def equals(other: Any): Boolean = other.isInstanceOf[GeoParquetFileFormat] && + other.asInstanceOf[GeoParquetFileFormat].spatialFilter == spatialFilter + + override def hashCode(): Int = getClass.hashCode() + + def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat = + new GeoParquetFileFormat(Some(spatialFilter)) + + override def inferSchema( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + GeoParquetUtils.inferSchema(sparkSession, parameters, files) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) + + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[ParquetOutputCommitter], + classOf[OutputCommitter]) + + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo( + "Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo( + "Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + + conf.setClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, classOf[OutputCommitter]) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. 
The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // This metadata is useful for keeping UDTs like Vector/Matrix. + ParquetWriteSupport.setSchema(dataSchema, conf) + + // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet + // schema and writes actual rows to Parquet files. + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) + + conf.set( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + + try { + val fieldIdWriteEnabled = + SQLConf.get.getConfString("spark.sql.parquet.fieldId.write.enabled") + conf.set("spark.sql.parquet.fieldId.write.enabled", fieldIdWriteEnabled) + } catch { + case e: NoSuchElementException => () + } + + // Sets compression scheme + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) + + // SPARK-15719: Disables writing Parquet summary files by default. + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) + } + + if (ParquetOutputFormat.getJobSummaryLevel(conf) != JobSummaryLevel.NONE + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning( + s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. 
" + + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") + } + + conf.set(ParquetOutputFormat.WRITE_SUPPORT_CLASS, classOf[GeoParquetWriteSupport].getName) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" + } + } + } + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, requiredSchema.json) + hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, requiredSchema.json) + hadoopConf.set( + SQLConf.SESSION_LOCAL_TIMEZONE.key, + sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sparkSession.sessionState.conf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sparkSession.sessionState.conf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + // TODO: if you move this into the closure it reverts to the default values. + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). 
+ val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader: Boolean = + sqlConf.parquetVectorizedReaderEnabled && + resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + + (file: PartitionedFile) => { + assert(file.partitionValues.numFields == partitionSchema.size) + + val filePath = new Path(new URI(file.filePath)) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) + + val sharedConf = broadcastedHadoopConf.value.value + + val footerFileMetaData = + ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new GeoParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringStartWith, + pushDownInFilterThreshold, + isCaseSensitive) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter(_)) + .reduceOption(FilterApi.and) + } else { + None + } + + // Prune file scans using pushed down spatial filters and per-column bboxes in geoparquet metadata + val shouldScanFile = + GeoParquetMetaData.parseKeyValueMetaData(footerFileMetaData.getKeyValueMetaData).forall { + metadata => spatialFilter.forall(_.evaluate(metadata.columns)) + } + if (!shouldScanFile) { + // The entire file is pruned so that we don't need to scan this file. + Seq.empty[InternalRow].iterator + } else { + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' + // *only* if the file was created by something other than "parquet-mr", so check the actual + // writer here for this file. We have to do this per-file, as each file in the table may + // have different writers. + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. 
+ def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy().startsWith("parquet-mr") + + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + val datetimeRebaseMode = GeoDataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_READ)) + val int96RebaseMode = GeoDataSourceUtils.int96RebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_READ)) + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = + new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId) + + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. + if (pushed.isDefined) { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + } + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + logWarning( + s"GeoParquet currently does not support vectorized reader. Falling back to parquet-mr") + } + logDebug(s"Falling back to parquet-mr") + // ParquetRecordReader returns InternalRow + val readSupport = new GeoParquetReadSupport( + convertTz, + enableVectorizedReader = false, + datetimeRebaseMode, + int96RebaseMode, + options) + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[InternalRow](readSupport, parquetFilter) + } else { + new ParquetRecordReader[InternalRow](readSupport) + } + val iter = new RecordReaderIterator[InternalRow](reader) + // SPARK-23457 Register a task completion listener before `initialization`. + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + reader.initialize(split, hadoopAttemptContext) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + if (partitionSchema.length == 0) { + // There is no partition columns + iter.map(unsafeProjection) + } else { + val joinedRow = new JoinedRow() + iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues))) + } + } + } + } + + override def supportDataType(dataType: DataType): Boolean = super.supportDataType(dataType) + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = false +} + +object GeoParquetFileFormat extends Logging { + + /** + * Figures out a merged Parquet schema with a distributed Spark job. + * + * Note that locality is not taken into consideration here because: + * + * 1. For a single Parquet part-file, in most cases the footer only resides in the last block + * of that file. Thus we only need to retrieve the location of the last block. However, + * Hadoop `FileSystem` only provides API to retrieve locations of all blocks, which can be + * potentially expensive. + * + * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty + * slow. And basically locality is not available when using S3 (you can't run computation on S3 + * nodes). 
+ */ + def mergeSchemasInParallel( + parameters: Map[String, String], + filesToTouch: Seq[FileStatus], + sparkSession: SparkSession): Option[StructType] = { + val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp + + val reader = (files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean) => { + readParquetFootersInParallel(conf, files, ignoreCorruptFiles) + .map { footer => + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val keyValueMetaData = footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData + val converter = new GeoParquetToSparkSchemaConverter( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp, + parameters = parameters) + readSchemaFromFooter(footer, keyValueMetaData, converter, parameters) + } + } + + GeoSchemaMergeUtils.mergeSchemasInParallel(sparkSession, parameters, filesToTouch, reader) + } + + private def readSchemaFromFooter( + footer: Footer, + keyValueMetaData: java.util.Map[String, String], + converter: GeoParquetToSparkSchemaConverter, + parameters: Map[String, String]): StructType = { + val fileMetaData = footer.getParquetMetadata.getFileMetaData + fileMetaData.getKeyValueMetaData.asScala.toMap + .get(ParquetReadSupport.SPARK_METADATA_KEY) + .flatMap(schema => deserializeSchemaString(schema, keyValueMetaData, parameters)) + .getOrElse(converter.convert(fileMetaData.getSchema)) + } + + private def deserializeSchemaString( + schemaString: String, + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): Option[StructType] = { + // Tries to deserialize the schema string as JSON first, then falls back to the case class + // string parser (data generated by older versions of Spark SQL uses this format). 
+ val schemaOpt = Try(DataType.fromJson(schemaString).asInstanceOf[StructType]) + .recover { case _: Throwable => + logInfo( + "Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + LegacyTypeStringParser.parseString(schemaString).asInstanceOf[StructType] + } + .recoverWith { case cause: Throwable => + logWarning( + "Failed to parse and ignored serialized Spark schema in " + + s"Parquet key-value metadata:\n\t$schemaString", + cause) + Failure(cause) + } + .toOption + + schemaOpt.map(schema => + replaceGeometryColumnWithGeometryUDT(schema, keyValueMetaData, parameters)) + } + + private def replaceGeometryColumnWithGeometryUDT( + schema: StructType, + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): StructType = { + val geoParquetMetaData: GeoParquetMetaData = + GeoParquetUtils.parseGeoParquetMetaData(keyValueMetaData, parameters) + val fields = schema.fields.map { field => + field.dataType match { + case _: BinaryType if geoParquetMetaData.columns.contains(field.name) => + field.copy(dataType = GeometryUDT) + case _ => field + } + } + StructType(fields) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala new file mode 100644 index 0000000000..d44f679058 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala @@ -0,0 +1,678 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ + +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +// Needed by Sedona to support Spark 3.0 - 3.3 +/** + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class GeoParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + case p: PrimitiveType => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getOriginalType, + p.getPrimitiveTypeName, + p.getTypeLength, + p.getDecimalMetadata))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. + case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + (field.fieldNames.toSeq.quoted, field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. 
+ val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + length: Int, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) + + private def dateToDays(date: Any): Int = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + + private def timestampToMicros(v: Any): JLong = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = GeoDateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => 
FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + 
Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => 
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + 
FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. 
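Note for readers: every branch of the comparison builders above (makeLt, makeLtEq, makeGt, makeGtEq) bottoms out in parquet-mr's public FilterApi. The following stand-alone sketch, with made-up column names and values that are not part of this patch, shows roughly what such a pushed-down predicate looks like:

    import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
    import org.apache.parquet.io.api.Binary

    object FilterPredicateSketch {
      def main(args: Array[String]): Unit = {
        // Roughly what pushing down `age > 18 AND name <= "m"` produces via the builders above.
        val agePred: FilterPredicate = FilterApi.gt(FilterApi.intColumn("age"), Integer.valueOf(18))
        val namePred: FilterPredicate =
          FilterApi.ltEq(FilterApi.binaryColumn("name"), Binary.fromString("m"))
        println(FilterApi.and(agePred, namePred)) // e.g. and(gt(age, 18), lteq(name, Binary{"m"}))
      }
    }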
+  private def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
+    value == null || (nameToParquetField(name).fieldType match {
+      case ParquetBooleanType => value.isInstanceOf[JBoolean]
+      case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
+      case ParquetLongType => value.isInstanceOf[JLong]
+      case ParquetFloatType => value.isInstanceOf[JFloat]
+      case ParquetDoubleType => value.isInstanceOf[JDouble]
+      case ParquetStringType => value.isInstanceOf[String]
+      case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
+      case ParquetDateType =>
+        value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
+      case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
+        value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]
+      case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
+        isDecimalMatched(value, decimalMeta)
+      case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) =>
+        isDecimalMatched(value, decimalMeta)
+      case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) =>
+        isDecimalMatched(value, decimalMeta)
+      case _ => false
+    })
+  }
+
+  // For decimal columns, the scale of the filter value must match the scale recorded in the
+  // file; pushing down a filter with a mismatched scale would cause data corruption.
+  private def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match {
+    case decimal: JBigDecimal =>
+      decimal.scale == decimalMeta.getScale
+    case _ => false
+  }
+
+  private def canMakeFilterOn(name: String, value: Any): Boolean = {
+    nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
+  }
+
+  /**
+   * @param predicate
+   *   the input filter predicates. Not all of the predicates can be pushed down.
+   * @param canPartialPushDownConjuncts
+   *   whether a subset of conjuncts of predicates can be pushed down safely. Pushing only one
+   *   side of an AND down is safe at the top level, or when none of its ancestors is a NOT or
+   *   an OR.
+   * @return
+   *   the Parquet-native filter predicates that are eligible for pushdown.
+   */
+  private def createFilterHelper(
+      predicate: sources.Filter,
+      canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = {
+    // NOTE:
+    //
+    // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
+    // which can be casted to `false` implicitly. Please refer to the `eval` method of these
+    // operators and the `PruneFilters` rule for details.
+
+    // Hyukjin:
+    // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]], so it
+    // performs the same equality comparison as [[EqualTo]]. The reason is that the underlying
+    // Parquet filter already performs a null-safe equality comparison.
+    // Maybe [[EqualTo]] should be changed instead, but this still seems fine, because physical
+    // planning never passes `NULL` to [[EqualTo]]; it rewrites such predicates to [[IsNull]] and
+    // the like. If I missed something, this should be revisited.
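One more note before the match below: the And and Not cases refuse to drop an unconvertible conjunct underneath a NOT. A minimal, stand-alone sketch (plain Scala with made-up data, not part of the patch) of why that restriction is needed:

    object PartialPushdownSketch {
      def main(args: Array[String]): Unit = {
        // Rows of (a, b); the original predicate is NOT(a = 2 AND b = "1").
        val rows = Seq((2, "1"), (3, "1"), (2, "2"))
        val correct = rows.filter { case (a, b) => !(a == 2 && b == "1") }
        // Dropping the unconvertible conjunct under the NOT turns it into NOT(a = 2),
        // which wrongly discards the row (2, "2").
        val wrong = rows.filter { case (a, _) => !(a == 2) }
        println(correct) // List((3,1), (2,2))
        println(wrong)   // List((3,1))
      }
    }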
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if canMakeFilterOn(name, values.head) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct + .flatMap { v => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case _ => None + } + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala new file mode 100644 index 0000000000..a3c2be5d22 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.io.api.RecordMaterializer +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema._ +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ + +import java.time.ZoneId +import java.util.{Locale, Map => JMap} +import scala.collection.JavaConverters._ + +/** + * A Parquet [[ReadSupport]] implementation for reading Parquet records as Catalyst + * [[InternalRow]]s. + * + * The API interface of [[ReadSupport]] is a little bit over complicated because of historical + * reasons. In older versions of parquet-mr (say 1.6.0rc3 and prior), [[ReadSupport]] need to be + * instantiated and initialized twice on both driver side and executor side. The [[init()]] method + * is for driver side initialization, while [[prepareForRead()]] is for executor side. However, + * starting from parquet-mr 1.6.0, it's no longer the case, and [[ReadSupport]] is only + * instantiated and initialized on executor side. So, theoretically, now it's totally fine to + * combine these two methods into a single initialization method. The only reason (I could think + * of) to still have them here is for parquet-mr API backwards-compatibility. + * + * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from + * [[init()]] to [[prepareForRead()]], but use a private `var` for simplicity. + */ +class GeoParquetReadSupport( + override val convertTz: Option[ZoneId], + enableVectorizedReader: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String]) + extends ParquetReadSupport + with Logging { + private var catalystRequestedSchema: StructType = _ + + /** + * Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record + * readers. Responsible for figuring out Parquet requested schema used for column pruning. + */ + override def init(context: InitContext): ReadContext = { + val conf = context.getConfiguration + catalystRequestedSchema = { + val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + assert(schemaString != null, "Parquet requested schema not set.") + StructType.fromString(schemaString) + } + + val caseSensitive = + conf.getBoolean(SQLConf.CASE_SENSITIVE.key, SQLConf.CASE_SENSITIVE.defaultValue.get) + val schemaPruningEnabled = conf.getBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) + val parquetFileSchema = context.getFileSchema + val parquetClippedSchema = ParquetReadSupport.clipParquetSchema( + parquetFileSchema, + catalystRequestedSchema, + caseSensitive) + + // We pass two schema to ParquetRecordMaterializer: + // - parquetRequestedSchema: the schema of the file data we want to read + // - catalystRequestedSchema: the schema of the rows we want to return + // The reader is responsible for reconciling the differences between the two. 
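As background for init() above: the driver serializes the requested Catalyst schema into the Hadoop configuration as a JSON string, and this executor-side ReadSupport parses it back before clipping the Parquet file schema. A rough, stand-alone sketch of that round trip using the public DataType.fromJson API (the actual code uses a Spark-internal helper, and the field names here are made up):

    import org.apache.spark.sql.types._

    object RequestedSchemaSketch {
      def main(args: Array[String]): Unit = {
        val requested = StructType(Seq(
          StructField("id", LongType),
          // Stand-in for the GeometryUDT column; plain BinaryType keeps the sketch dependency-free.
          StructField("geometry", BinaryType)))
        val onTheWire = requested.json
        val parsedBack = DataType.fromJson(onTheWire).asInstanceOf[StructType]
        assert(parsedBack == requested)
      }
    }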
+ val parquetRequestedSchema = if (schemaPruningEnabled && !enableVectorizedReader) { + // Parquet-MR reader requires that parquetRequestedSchema include only those fields present + // in the underlying parquetFileSchema. Therefore, we intersect the parquetClippedSchema + // with the parquetFileSchema + GeoParquetReadSupport + .intersectParquetGroups(parquetClippedSchema, parquetFileSchema) + .map(groupType => new MessageType(groupType.getName, groupType.getFields)) + .getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE) + } else { + // Spark's vectorized reader only support atomic types currently. It also skip fields + // in parquetRequestedSchema which are not present in the file. + parquetClippedSchema + } + logDebug( + s"""Going to read the following fields from the Parquet file with the following schema: + |Parquet file schema: + |$parquetFileSchema + |Parquet clipped schema: + |$parquetClippedSchema + |Parquet requested schema: + |$parquetRequestedSchema + |Catalyst requested schema: + |${catalystRequestedSchema.treeString} + """.stripMargin) + new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. + */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + val parquetRequestedSchema = readContext.getRequestedSchema + new GeoParquetRecordMaterializer( + parquetRequestedSchema, + GeoParquetReadSupport.expandUDT(catalystRequestedSchema), + new GeoParquetToSparkSchemaConverter(keyValueMetaData, conf, parameters), + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters) + } +} + +object GeoParquetReadSupport extends Logging { + + /** + * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist in + * `catalystSchema`, and adding those only exist in `catalystSchema`. + */ + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + val clippedParquetFields = + clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema, caseSensitive) + if (clippedParquetFields.isEmpty) { + ParquetSchemaConverter.EMPTY_MESSAGE + } else { + Types + .buildMessage() + .addFields(clippedParquetFields: _*) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + } + } + + private def clipParquetType( + parquetType: Type, + catalystType: DataType, + caseSensitive: Boolean): Type = { + catalystType match { + case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => + // Only clips array types with nested type as element type. + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + + case t: MapType + if !isPrimitiveCatalystType(t.keyType) || + !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested key type or value type + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + + case t: StructType => + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + + case _ => + // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able + // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. 
+ parquetType + } + } + + /** + * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to + * [[AtomicType]]. For example, [[CalendarIntervalType]] is primitive, but it's not an + * [[AtomicType]]. + */ + private def isPrimitiveCatalystType(dataType: DataType): Boolean = { + dataType match { + case _: ArrayType | _: MapType | _: StructType => false + case _ => true + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type + * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or + * a [[StructType]]. + */ + private def clipParquetListType( + parquetList: GroupType, + elementType: DataType, + caseSensitive: Boolean): Type = { + // Precondition of this method, should only be called for lists with nested element types. + assert(!isPrimitiveCatalystType(elementType)) + + // Unannotated repeated group should be interpreted as required list of required element, so + // list element type is just the group itself. Clip it. + if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { + clipParquetType(parquetList, elementType, caseSensitive) + } else { + assert( + parquetList.getOriginalType == OriginalType.LIST, + "Invalid Parquet schema. " + + "Original type of annotated Parquet lists must be LIST: " + + parquetList.toString) + + assert( + parquetList.getFieldCount == 1 && parquetList + .getType(0) + .isRepetition(Repetition.REPEATED), + "Invalid Parquet schema. " + + "LIST-annotated group should only have exactly one repeated field: " + + parquetList) + + // Precondition of this method, should only be called for lists with nested element types. + assert(!parquetList.getType(0).isPrimitive) + + val repeatedGroup = parquetList.getType(0).asGroupType() + + // If the repeated field is a group with multiple fields, or the repeated field is a group + // with one field and is named either "array" or uses the LIST-annotated group's name with + // "_tuple" appended then the repeated type is the element type and elements are required. + // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the + // only field. + if (repeatedGroup.getFieldCount > 1 || + repeatedGroup.getName == "array" || + repeatedGroup.getName == parquetList.getName + "_tuple") { + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .named(parquetList.getName) + } else { + // Otherwise, the repeated field's type is the element type with the repeated field's + // repetition. + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField( + Types + .repeatedGroup() + .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) + .named(repeatedGroup.getName)) + .named(parquetList.getName) + } + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. Either key type or + * value type of the [[MapType]] must be a nested type, namely an [[ArrayType]], a [[MapType]], + * or a [[StructType]]. + */ + private def clipParquetMapType( + parquetMap: GroupType, + keyType: DataType, + valueType: DataType, + caseSensitive: Boolean): GroupType = { + // Precondition of this method, only handles maps with nested key types or value types. 
+ assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) + + val repeatedGroup = parquetMap.getType(0).asGroupType() + val parquetKeyType = repeatedGroup.getType(0) + val parquetValueType = repeatedGroup.getType(1) + + val clippedRepeatedGroup = + Types + .repeatedGroup() + .as(repeatedGroup.getOriginalType) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .named(repeatedGroup.getName) + + Types + .buildGroup(parquetMap.getRepetition) + .as(parquetMap.getOriginalType) + .addField(clippedRepeatedGroup) + .named(parquetMap.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return + * A clipped [[GroupType]], which has at least one field. + * @note + * Parquet doesn't allow creating empty [[GroupType]] instances except for empty + * [[MessageType]]. Because it's legal to construct an empty requested schema for column + * pruning. + */ + private def clipParquetGroup( + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) + Types + .buildGroup(parquetRecord.getRepetition) + .as(parquetRecord.getOriginalType) + .addFields(clippedParquetFields: _*) + .named(parquetRecord.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return + * A list of clipped [[GroupType]] fields, which can be empty. + */ + private def clipParquetGroupFields( + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean): Seq[Type] = { + val toParquet = new SparkToGeoParquetSchemaConverter(writeLegacyParquetFormat = false) + if (caseSensitive) { + val caseSensitiveParquetFieldMap = + parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + structType.map { f => + caseSensitiveParquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType, caseSensitive)) + .getOrElse(toParquet.convertField(f)) + } + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + structType.map { f => + caseInsensitiveParquetFieldMap + .get(f.name.toLowerCase(Locale.ROOT)) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException( + s"""Found duplicate field(s) "${f.name}": """ + + s"$parquetTypesString in case-insensitive mode") + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + } + } + .getOrElse(toParquet.convertField(f)) + } + } + } + + /** + * Computes the structural intersection between two Parquet group types. This is used to create + * a requestedSchema for ReadContext of Parquet-MR reader. Parquet-MR reader does not support + * the nested field access to non-existent field while parquet library does support to read the + * non-existent field by regular field access. 
+ */ + private def intersectParquetGroups( + groupType1: GroupType, + groupType2: GroupType): Option[GroupType] = { + val fields = + groupType1.getFields.asScala + .filter(field => groupType2.containsField(field.getName)) + .flatMap { + case field1: GroupType => + val field2 = groupType2.getType(field1.getName) + if (field2.isPrimitive) { + None + } else { + intersectParquetGroups(field1, field2.asGroupType) + } + case field1 => Some(field1) + } + + if (fields.nonEmpty) { + Some(groupType1.withNewFields(fields.asJava)) + } else { + None + } + } + + def expandUDT(schema: StructType): StructType = { + def expand(dataType: DataType): DataType = { + dataType match { + case t: ArrayType => + t.copy(elementType = expand(t.elementType)) + + case t: MapType => + t.copy(keyType = expand(t.keyType), valueType = expand(t.valueType)) + + case t: StructType => + val expandedFields = t.fields.map(f => f.copy(dataType = expand(f.dataType))) + t.copy(fields = expandedFields) + + // Don't expand GeometryUDT types. We'll treat geometry columns specially in + // GeoParquetRowConverter + case t: GeometryUDT => t + + case t: UserDefinedType[_] => + t.sqlType + + case t => + t + } + } + + expand(schema).asInstanceOf[StructType] + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala new file mode 100644 index 0000000000..dedbb237b5 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.time.ZoneId +import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} +import org.apache.parquet.schema.MessageType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.types.StructType + +/** + * A [[RecordMaterializer]] for Catalyst rows. 
+ * + * @param parquetSchema + * Parquet schema of the records to be read + * @param catalystSchema + * Catalyst schema of the rows to be constructed + * @param schemaConverter + * A Parquet-Catalyst schema converter that helps initializing row converters + * @param convertTz + * the optional time zone to convert to int96 data + * @param datetimeRebaseSpec + * the specification of rebasing date/timestamp from Julian to Proleptic Gregorian calendar: + * mode + optional original time zone + * @param int96RebaseSpec + * the specification of rebasing INT96 timestamp from Julian to Proleptic Gregorian calendar + * @param parameters + * Options for reading GeoParquet files. For example, if legacyMode is enabled or not. + */ +class GeoParquetRecordMaterializer( + parquetSchema: MessageType, + catalystSchema: StructType, + schemaConverter: GeoParquetToSparkSchemaConverter, + convertTz: Option[ZoneId], + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String]) + extends RecordMaterializer[InternalRow] { + private val rootConverter = new GeoParquetRowConverter( + schemaConverter, + parquetSchema, + catalystSchema, + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters, + NoopUpdater) + + override def getCurrentRecord: InternalRow = rootConverter.currentRecord + + override def getRootConverter: GroupConverter = rootConverter +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala new file mode 100644 index 0000000000..2f2eea38cd --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala @@ -0,0 +1,745 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.OriginalType.LIST +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.{GroupType, OriginalType, Type} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CaseInsensitiveMap, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.locationtech.jts.io.WKBReader + +import java.math.{BigDecimal, BigInteger} +import java.time.{ZoneId, ZoneOffset} +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +/** + * A [[ParquetRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s. + * Since Catalyst `StructType` is also a Parquet record, this converter can be used as root + * converter. Take the following Parquet type as an example: + * {{{ + * message root { + * required int32 f1; + * optional group f2 { + * required double f21; + * optional binary f22 (utf8); + * } + * } + * }}} + * 5 converters will be created: + * + * - a root [[ParquetRowConverter]] for [[org.apache.parquet.schema.MessageType]] `root`, which + * contains: + * - a [[ParquetPrimitiveConverter]] for required + * [[org.apache.parquet.schema.OriginalType.INT_32]] field `f1`, and + * - a nested [[ParquetRowConverter]] for optional [[GroupType]] `f2`, which contains: + * - a [[ParquetPrimitiveConverter]] for required + * [[org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE]] field `f21`, and + * - a [[ParquetStringConverter]] for optional + * [[org.apache.parquet.schema.OriginalType.UTF8]] string field `f22` + * + * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have + * any "parent" container. + * + * @param schemaConverter + * A utility converter used to convert Parquet types to Catalyst types. + * @param parquetType + * Parquet schema of Parquet records + * @param catalystType + * Spark SQL schema that corresponds to the Parquet record type. User-defined types other than + * [[GeometryUDT]] should have been expanded. + * @param convertTz + * the optional time zone to convert to int96 data + * @param datetimeRebaseMode + * the mode of rebasing date/timestamp from Julian to Proleptic Gregorian calendar + * @param int96RebaseMode + * the mode of rebasing INT96 timestamp from Julian to Proleptic Gregorian calendar + * @param parameters + * Options for reading GeoParquet files. For example, if legacyMode is enabled or not. 
+ * @param updater + * An updater which propagates converted field values to the parent container + */ +private[parquet] class GeoParquetRowConverter( + schemaConverter: GeoParquetToSparkSchemaConverter, + parquetType: GroupType, + catalystType: StructType, + convertTz: Option[ZoneId], + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String], + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) + with Logging { + + assert( + parquetType.getFieldCount <= catalystType.length, + s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema: + | + |Parquet schema: + |$parquetType + |Catalyst schema: + |${catalystType.prettyJson} + """.stripMargin) + + assert( + !catalystType.existsRecursively(t => + !t.isInstanceOf[GeometryUDT] && t.isInstanceOf[UserDefinedType[_]]), + s"""User-defined types in Catalyst schema should have already been expanded: + |${catalystType.prettyJson} + """.stripMargin) + + logDebug(s"""Building row converter for the following schema: + | + |Parquet form: + |$parquetType + |Catalyst form: + |${catalystType.prettyJson} + """.stripMargin) + + /** + * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates + * converted filed values to the `ordinal`-th cell in `currentRow`. + */ + private final class RowUpdater(row: InternalRow, ordinal: Int) extends ParentContainerUpdater { + override def set(value: Any): Unit = row(ordinal) = value + override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(value: Short): Unit = row.setShort(ordinal, value) + override def setInt(value: Int): Unit = row.setInt(ordinal, value) + override def setLong(value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) + } + + private[this] val currentRow = new SpecificInternalRow(catalystType.map(_.dataType)) + + /** + * The [[InternalRow]] converted from an entire Parquet record. + */ + def currentRecord: InternalRow = currentRow + + private val dateRebaseFunc = + GeoDataSourceUtils.creteDateRebaseFuncInRead(datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInRead(datetimeRebaseMode, "Parquet") + + private val int96RebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInRead(int96RebaseMode, "Parquet INT96") + + // Converters for each field. + private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { + // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false + // to prevent throwing IllegalArgumentException when searching catalyst type's field index + val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) { + catalystType.fieldNames.zipWithIndex.toMap + } else { + CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap) + } + parquetType.getFields.asScala.map { parquetField => + val fieldIndex = catalystFieldNameToIndex(parquetField.getName) + val catalystField = catalystType(fieldIndex) + // Converted field value should be set to the `fieldIndex`-th cell of `currentRow` + newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex)) + }.toArray + } + + // Updaters for each field. 
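A note on the geometry path that appears in newConverter further below: when the Parquet type is primitive, the GeometryUDT branch decodes each BINARY value as WKB and stores the re-serialized geometry in the row. A minimal sketch of that decode step with plain JTS (the point is hypothetical and not part of the patch):

    import org.locationtech.jts.geom.{Coordinate, GeometryFactory}
    import org.locationtech.jts.io.{WKBReader, WKBWriter}

    object WkbDecodeSketch {
      def main(args: Array[String]): Unit = {
        val point = new GeometryFactory().createPoint(new Coordinate(1.0, 2.0))
        val wkb = new WKBWriter().write(point) // what a GeoParquet file stores for one geometry value
        val decoded = new WKBReader().read(wkb) // what the addBinary override does with value.getBytes
        println(decoded) // POINT (1 2)
      }
    }

For files written by Apache Sedona <= 1.3.1-incubating, the geometry is instead stored as a repeated group of bytes, which is why the non-primitive branch requires option("legacyMode", "true"), as the error message in that branch states.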
+ private[this] val fieldUpdaters: Array[ParentContainerUpdater] = fieldConverters.map(_.updater) + + override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) + + override def end(): Unit = { + var i = 0 + while (i < fieldUpdaters.length) { + fieldUpdaters(i).end() + i += 1 + } + updater.set(currentRow) + } + + override def start(): Unit = { + var i = 0 + val numFields = currentRow.numFields + while (i < numFields) { + currentRow.setNullAt(i) + i += 1 + } + i = 0 + while (i < fieldUpdaters.length) { + fieldUpdaters(i).start() + i += 1 + } + } + + /** + * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type + * `catalystType`. Converted values are handled by `updater`. + */ + private def newConverter( + parquetType: Type, + catalystType: DataType, + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { + + catalystType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => + new ParquetPrimitiveConverter(updater) + + case GeometryUDT => + if (parquetType.isPrimitive) { + new ParquetPrimitiveConverter(updater) { + override def addBinary(value: Binary): Unit = { + val wkbReader = new WKBReader() + val geom = wkbReader.read(value.getBytes) + updater.set(GeometryUDT.serialize(geom)) + } + } + } else { + if (GeoParquetUtils.isLegacyMode(parameters)) { + new ParquetArrayConverter( + parquetType.asGroupType(), + ArrayType(ByteType, containsNull = false), + updater) { + override def end(): Unit = { + val wkbReader = new WKBReader() + val byteArray = currentArray.map(_.asInstanceOf[Byte]).toArray + val geom = wkbReader.read(byteArray) + updater.set(GeometryUDT.serialize(geom)) + } + } + } else { + throw new IllegalArgumentException( + s"Parquet type for geometry column is $parquetType. This parquet file could be written by " + + "Apache Sedona <= 1.3.1-incubating. Please use option(\"legacyMode\", \"true\") to read this file.") + } + } + + case ByteType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + updater.setByte(value.asInstanceOf[ByteType#InternalType]) + + override def addBinary(value: Binary): Unit = { + val bytes = value.getBytes + for (b <- bytes) { + updater.set(b) + } + } + } + + case ShortType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + updater.setShort(value.asInstanceOf[ShortType#InternalType]) + } + + // For INT32 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => + new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For INT64 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 => + new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals + case t: DecimalType + if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY || + parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY => + new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + case t: DecimalType => + throw new RuntimeException( + s"Unable to create Parquet converter for decimal type ${t.json} whose Parquet type is " + + s"$parquetType. 
Parquet DECIMAL type can only be backed by INT32, INT64, " + + "FIXED_LEN_BYTE_ARRAY, or BINARY.") + + case StringType => + new ParquetStringConverter(updater) + + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(timestampRebaseFunc(value)) + } + } + + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + val micros = GeoDateTimeUtils.millisToMicros(value) + updater.setLong(timestampRebaseFunc(micros)) + } + } + + // INT96 timestamp doesn't have a logical type, here we check the physical type instead. + case TimestampType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT96 => + new ParquetPrimitiveConverter(updater) { + // Converts nanosecond timestamps stored as INT96 + override def addBinary(value: Binary): Unit = { + val julianMicros = ParquetRowConverter.binaryToSQLTimestamp(value) + val gregorianMicros = int96RebaseFunc(julianMicros) + val adjTime = convertTz + .map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) + .getOrElse(gregorianMicros) + updater.setLong(adjTime) + } + } + + case DateType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = { + updater.set(dateRebaseFunc(value)) + } + } + + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + + case t: ArrayType => + new ParquetArrayConverter(parquetType.asGroupType(), t, updater) + + case t: MapType => + new ParquetMapConverter(parquetType.asGroupType(), t, updater) + + case t: StructType => + val wrappedUpdater = { + // SPARK-30338: avoid unnecessary InternalRow copying for nested structs: + // There are two cases to handle here: + // + // 1. Parent container is a map or array: we must make a deep copy of the mutable row + // because this converter may be invoked multiple times per Parquet input record + // (if the map or array contains multiple elements). + // + // 2. Parent container is a struct: we don't need to copy the row here because either: + // + // (a) all ancestors are structs and therefore no copying is required because this + // converter will only be invoked once per Parquet input record, or + // (b) some ancestor is struct that is nested in a map or array and that ancestor's + // converter will perform deep-copying (which will recursively copy this row). + if (updater.isInstanceOf[RowUpdater]) { + // `updater` is a RowUpdater, implying that the parent container is a struct. + updater + } else { + // `updater` is NOT a RowUpdater, implying that the parent container a map or array. 
+ new ParentContainerUpdater { + override def set(value: Any): Unit = { + updater.set(value.asInstanceOf[SpecificInternalRow].copy()) // deep copy + } + } + } + } + new GeoParquetRowConverter( + schemaConverter, + parquetType.asGroupType(), + t, + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters, + wrappedUpdater) + + case t => + throw new RuntimeException( + s"Unable to create Parquet converter for data type ${t.json} " + + s"whose Parquet type is $parquetType") + } + } + + /** + * Parquet converter for strings. A dictionary is used to minimize string decoding cost. + */ + private final class ParquetStringConverter(updater: ParentContainerUpdater) + extends ParquetPrimitiveConverter(updater) { + + private var expandedDictionary: Array[UTF8String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { i => + UTF8String.fromBytes(dictionary.decodeToBinary(i).getBytes) + } + } + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + override def addBinary(value: Binary): Unit = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here we + // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying + // it. + val buffer = value.toByteBuffer + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() + updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) + } + } + + /** + * Parquet converter for fixed-precision decimals. + */ + private abstract class ParquetDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetPrimitiveConverter(updater) { + + protected var expandedDictionary: Array[Decimal] = _ + + override def hasDictionarySupport: Boolean = true + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + // Converts decimals stored as INT32 + override def addInt(value: Int): Unit = { + addLong(value: Long) + } + + // Converts decimals stored as INT64 + override def addLong(value: Long): Unit = { + updater.set(decimalFromLong(value)) + } + + // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY + override def addBinary(value: Binary): Unit = { + updater.set(decimalFromBinary(value)) + } + + protected def decimalFromLong(value: Long): Decimal = { + Decimal(value, precision, scale) + } + + protected def decimalFromBinary(value: Binary): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + val unscaled = ParquetRowConverter.binaryToUnscaledLong(value) + Decimal(unscaled, precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
+        Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale)
+      }
+    }
+  }
+
+  private class ParquetIntDictionaryAwareDecimalConverter(
+      precision: Int,
+      scale: Int,
+      updater: ParentContainerUpdater)
+    extends ParquetDecimalConverter(precision, scale, updater) {
+
+    override def setDictionary(dictionary: Dictionary): Unit = {
+      this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+        decimalFromLong(dictionary.decodeToInt(id).toLong)
+      }
+    }
+  }
+
+  private class ParquetLongDictionaryAwareDecimalConverter(
+      precision: Int,
+      scale: Int,
+      updater: ParentContainerUpdater)
+    extends ParquetDecimalConverter(precision, scale, updater) {
+
+    override def setDictionary(dictionary: Dictionary): Unit = {
+      this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+        decimalFromLong(dictionary.decodeToLong(id))
+      }
+    }
+  }
+
+  private class ParquetBinaryDictionaryAwareDecimalConverter(
+      precision: Int,
+      scale: Int,
+      updater: ParentContainerUpdater)
+    extends ParquetDecimalConverter(precision, scale, updater) {
+
+    override def setDictionary(dictionary: Dictionary): Unit = {
+      this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+        decimalFromBinary(dictionary.decodeToBinary(id))
+      }
+    }
+  }
+
+  /**
+   * Parquet converter for arrays. Spark SQL arrays are represented as Parquet lists. Standard
+   * Parquet lists are represented as a 3-level group annotated by `LIST`:
+   * {{{
+   *   <list-repetition> group <name> (LIST) {            <-- parquetSchema points here
+   *     repeated group list {
+   *       <element-repetition> <element-type> element;
+   *     }
+   *   }
+   * }}}
+   * The `parquetSchema` constructor argument points to the outermost group.
+   *
+   * However, before this representation is standardized, some Parquet libraries/tools also use
+   * some non-standard formats to represent list-like structures. Backwards-compatibility rules
+   * for handling these cases are described in Parquet format spec.
+   *
+   * @see
+   *   https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists
+   */
+  private class ParquetArrayConverter(
+      parquetSchema: GroupType,
+      catalystSchema: ArrayType,
+      updater: ParentContainerUpdater)
+    extends ParquetGroupConverter(updater) {
+
+    protected[this] val currentArray: mutable.ArrayBuffer[Any] = ArrayBuffer.empty[Any]
+
+    private[this] val elementConverter: Converter = {
+      val repeatedType = parquetSchema.getType(0)
+      val elementType = catalystSchema.elementType
+
+      // At this stage, we're not sure whether the repeated field maps to the element type or is
+      // just the syntactic repeated group of the 3-level standard LIST layout. Take the following
+      // Parquet LIST-annotated group type as an example:
+      //
+      //    optional group f (LIST) {
+      //      repeated group list {
+      //        optional group element {
+      //          optional int32 element;
+      //        }
+      //      }
+      //    }
+      //
+      // This type is ambiguous:
+      //
+      // 1. When interpreted as a standard 3-level layout, the `list` field is just the syntactic
+      //    group, and the entire type should be translated to:
+      //
+      //      ARRAY<STRUCT<element: INT>>
+      //
+      // 2. On the other hand, when interpreted as a non-standard 2-level layout, the `list` field
+      //    represents the element type, and the entire type should be translated to:
+      //
+      //      ARRAY<STRUCT<element: STRUCT<element: INT>>>
+      //
+      // Here we try to convert field `list` into a Catalyst type to see whether the converted type
+      // matches the Catalyst array element type. If it doesn't match, then it's case 1; otherwise,
+      // it's case 2.
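To make the two readings above concrete, here are the Catalyst shapes that the same Parquet text corresponds to under each interpretation (an illustrative sketch only; the guessedElementType check below is what actually picks between them):

    import org.apache.spark.sql.types._

    object ListLayoutSketch {
      def main(args: Array[String]): Unit = {
        // Case 1: `list` is the syntactic 3-level wrapper -> ARRAY<STRUCT<element: INT>>
        val standardThreeLevel = ArrayType(new StructType().add("element", IntegerType))
        // Case 2: `list` itself is the element type -> ARRAY<STRUCT<element: STRUCT<element: INT>>>
        val legacyTwoLevel = ArrayType(
          new StructType().add("element", new StructType().add("element", IntegerType)))
        println(standardThreeLevel.simpleString) // array<struct<element:int>>
        println(legacyTwoLevel.simpleString)     // array<struct<element:struct<element:int>>>
      }
    }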
+ val guessedElementType = schemaConverter.convertFieldWithGeo(repeatedType) + + if (DataType.equalsIgnoreCompatibleNullability(guessedElementType, elementType)) { + // If the repeated field corresponds to the element type, creates a new converter using the + // type of the repeated field. + newConverter( + repeatedType, + elementType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentArray += value + }) + } else { + // If the repeated field corresponds to the syntactic group in the standard 3-level Parquet + // LIST layout, creates a new converter using the only child field of the repeated field. + assert(!repeatedType.isPrimitive && repeatedType.asGroupType().getFieldCount == 1) + new ElementConverter(repeatedType.asGroupType().getType(0), elementType) + } + } + + override def getConverter(fieldIndex: Int): Converter = elementConverter + + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + + override def start(): Unit = currentArray.clear() + + /** Array element converter */ + private final class ElementConverter(parquetType: Type, catalystType: DataType) + extends GroupConverter { + + private var currentElement: Any = _ + + private[this] val converter = + newConverter( + parquetType, + catalystType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentElement = value + }) + + override def getConverter(fieldIndex: Int): Converter = converter + + override def end(): Unit = currentArray += currentElement + + override def start(): Unit = currentElement = null + } + } + + /** Parquet converter for maps */ + private final class ParquetMapConverter( + parquetType: GroupType, + catalystType: MapType, + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) { + + private[this] val currentKeys = ArrayBuffer.empty[Any] + private[this] val currentValues = ArrayBuffer.empty[Any] + + private[this] val keyValueConverter = { + val repeatedType = parquetType.getType(0).asGroupType() + new KeyValueConverter( + repeatedType.getType(0), + repeatedType.getType(1), + catalystType.keyType, + catalystType.valueType) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override def end(): Unit = { + // The parquet map may contains null or duplicated map keys. When it happens, the behavior is + // undefined. + // TODO (SPARK-26174): disallow it with a config. + updater.set( + new ArrayBasedMapData( + new GenericArrayData(currentKeys.toArray), + new GenericArrayData(currentValues.toArray))) + } + + override def start(): Unit = { + currentKeys.clear() + currentValues.clear() + } + + /** Parquet converter for key-value pairs within the map. 
*/ + private final class KeyValueConverter( + parquetKeyType: Type, + parquetValueType: Type, + catalystKeyType: DataType, + catalystValueType: DataType) + extends GroupConverter { + + private var currentKey: Any = _ + + private var currentValue: Any = _ + + private[this] val converters = Array( + // Converter for keys + newConverter( + parquetKeyType, + catalystKeyType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentKey = value + }), + + // Converter for values + newConverter( + parquetValueType, + catalystValueType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentValue = value + })) + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + override def end(): Unit = { + currentKeys += currentKey + currentValues += currentValue + } + + override def start(): Unit = { + currentKey = null + currentValue = null + } + } + } + + private trait RepeatedConverter { + private[this] val currentArray = ArrayBuffer.empty[Any] + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray.clear() + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. + */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter + with RepeatedConverter + with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private[this] val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = + elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. 
+ */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter + with HasParentContainerUpdater + with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private[this] val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala new file mode 100644 index 0000000000..eab20875a6 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema._ +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.checkConversionRequirement +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ + +/** + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]]. + * + * Parquet format backwards-compatibility rules are respected when converting Parquet + * [[MessageType]] schemas. + * + * @see + * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * @param assumeBinaryIsString + * Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields. + * @param assumeInt96IsTimestamp + * Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields. + * @param parameters + * Options for reading GeoParquet files. 
+ */ +class GeoParquetToSparkSchemaConverter( + keyValueMetaData: java.util.Map[String, String], + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + parameters: Map[String, String]) { + + private val geoParquetMetaData: GeoParquetMetaData = + GeoParquetUtils.parseGeoParquetMetaData(keyValueMetaData, parameters) + + def this( + keyValueMetaData: java.util.Map[String, String], + conf: SQLConf, + parameters: Map[String, String]) = this( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = conf.isParquetBinaryAsString, + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + parameters = parameters) + + def this( + keyValueMetaData: java.util.Map[String, String], + conf: Configuration, + parameters: Map[String, String]) = this( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + parameters = parameters) + + /** + * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. + */ + def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + + private def convert(parquetSchema: GroupType): StructType = { + val fields = parquetSchema.getFields.asScala.map { field => + field.getRepetition match { + case OPTIONAL => + StructField(field.getName, convertFieldWithGeo(field), nullable = true) + + case REQUIRED => + StructField(field.getName, convertFieldWithGeo(field), nullable = false) + + case REPEATED => + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + val arrayType = ArrayType(convertFieldWithGeo(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) + } + } + + StructType(fields.toSeq) + } + + /** + * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. + */ + def convertFieldWithGeo(parquetType: Type): DataType = parquetType match { + case t: PrimitiveType => convertPrimitiveField(t) + case t: GroupType => convertGroupField(t.asGroupType()) + } + + private def isGeometryField(fieldName: String): Boolean = + geoParquetMetaData.columns.contains(fieldName) + + private def convertPrimitiveField(field: PrimitiveType): DataType = { + val typeName = field.getPrimitiveTypeName + val originalType = field.getOriginalType + + def typeString = + if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + + def typeNotSupported() = + throw new IllegalArgumentException(s"Parquet type not supported: $typeString") + + def typeNotImplemented() = + throw new IllegalArgumentException(s"Parquet type not yet supported: $typeString") + + def illegalType() = + throw new IllegalArgumentException(s"Illegal Parquet type: $typeString") + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. 
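As a point of reference for the helper defined next, a minimal sketch (assuming Spark SQL's standard decimal bounds of 9 digits for INT32-backed and 18 digits for INT64-backed decimals) of the kinds of types these checks accept:

    import org.apache.spark.sql.types.DecimalType

    // maxPrecision is derived from the physical type at the call sites below:
    val int32Backed = DecimalType(9, 2)    // fits when maxPrecision = Decimal.MAX_INT_DIGITS (9)
    val int64Backed = DecimalType(18, 2)   // fits when maxPrecision = Decimal.MAX_LONG_DIGITS (18)
    val binaryBacked = DecimalType(38, 10) // read with maxPrecision = -1 (no range check)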
+ def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val precision = field.getDecimalMetadata.getPrecision + val scale = field.getDecimalMetadata.getScale + + ParquetSchemaConverter.checkConversionRequirement( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + typeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + originalType match { + case INT_8 => ByteType + case INT_16 => ShortType + case INT_32 | null => IntegerType + case DATE => DateType + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) + case UINT_8 => typeNotSupported() + case UINT_16 => typeNotSupported() + case UINT_32 => typeNotSupported() + case TIME_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + originalType match { + case INT_64 | null => LongType + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) + case UINT_64 => typeNotSupported() + case TIMESTAMP_MICROS => TimestampType + case TIMESTAMP_MILLIS => TimestampType + case _ => illegalType() + } + + case INT96 => + ParquetSchemaConverter.checkConversionRequirement( + assumeInt96IsTimestamp, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + originalType match { + case UTF8 | ENUM | JSON => StringType + case null if isGeometryField(field.getName) => GeometryUDT + case null if assumeBinaryIsString => StringType + case null => BinaryType + case BSON => BinaryType + case DECIMAL => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + originalType match { + case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) + case INTERVAL => typeNotImplemented() + case _ => illegalType() + } + + case _ => illegalType() + } + } + + private def convertGroupField(field: GroupType): DataType = { + Option(field.getOriginalType).fold(convert(field): DataType) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. 
+ // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case LIST => + ParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1, + s"Invalid list type $field") + + val repeatedType = field.getType(0) + ParquetSchemaConverter.checkConversionRequirement( + repeatedType.isRepetition(REPEATED), + s"Invalid list type $field") + + if (isElementTypeWithGeo(repeatedType, field.getName)) { + ArrayType(convertFieldWithGeo(repeatedType), containsNull = false) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val optional = elementType.isRepetition(OPTIONAL) + ArrayType(convertFieldWithGeo(elementType), containsNull = optional) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case MAP | MAP_KEY_VALUE => + ParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + + val keyValueType = field.getType(0).asGroupType() + ParquetSchemaConverter.checkConversionRequirement( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val keyType = keyValueType.getType(0) + val valueType = keyValueType.getType(1) + val valueOptional = valueType.isRepetition(OPTIONAL) + MapType( + convertFieldWithGeo(keyType), + convertFieldWithGeo(valueType), + valueContainsNull = valueOptional) + + case _ => + throw new IllegalArgumentException(s"Unrecognized Parquet type: $field") + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + def isElementTypeWithGeo(repeatedType: Type, parentName: String): Boolean = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // ARRAY (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } +} + +/** + * This converter class is used to convert Spark SQL [[StructType]] to Parquet [[MessageType]]. + * + * @param writeLegacyParquetFormat + * Whether to use legacy Parquet format compatible with Spark 1.4 and prior versions when + * converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. When set to false, use + * standard format defined in parquet-format spec. 
This argument only affects Parquet write + * path. + * @param outputTimestampType + * which parquet timestamp type to use when writing. + */ +class SparkToGeoParquetSchemaConverter( + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96) + extends SparkToParquetSchemaConverter(writeLegacyParquetFormat, outputTimestampType) { + + def this(conf: SQLConf) = this( + writeLegacyParquetFormat = conf.writeLegacyParquetFormat, + outputTimestampType = conf.parquetOutputTimestampType) + + def this(conf: Configuration) = this( + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, + outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) + + /** + * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. + */ + override def convert(catalystSchema: StructType): MessageType = { + Types + .buildMessage() + .addFields(catalystSchema.map(convertField): _*) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + } + + /** + * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. + */ + override def convertField(field: StructField): Type = { + convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + } + + private def convertField(field: StructField, repetition: Type.Repetition): Type = { + GeoParquetSchemaConverter.checkFieldName(field.name) + + field.dataType match { + // =================== + // Simple atomic types + // =================== + + case BooleanType => + Types.primitive(BOOLEAN, repetition).named(field.name) + + case ByteType => + Types.primitive(INT32, repetition).as(INT_8).named(field.name) + + case ShortType => + Types.primitive(INT32, repetition).as(INT_16).named(field.name) + + case IntegerType => + Types.primitive(INT32, repetition).named(field.name) + + case LongType => + Types.primitive(INT64, repetition).named(field.name) + + case FloatType => + Types.primitive(FLOAT, repetition).named(field.name) + + case DoubleType => + Types.primitive(DOUBLE, repetition).named(field.name) + + case StringType => + Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + + case DateType => + Types.primitive(INT32, repetition).as(DATE).named(field.name) + + // NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or + // TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the + // behavior same as before. + // + // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond + // timestamp in Impala for some historical reasons. It's not recommended to be used for any + // other types and will probably be deprecated in some future version of parquet-format spec. + // That's the reason why parquet-format spec only defines `TIMESTAMP_MILLIS` and + // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`. + // + // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting + // from Spark 1.5.0, we resort to a timestamp type with microsecond precision so that we can + // store a timestamp into a `Long`. This design decision is subject to change though, for + // example, we may resort to nanosecond precision in the future. 
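The choice among INT96, TIMESTAMP_MICROS and TIMESTAMP_MILLIS described above is driven by the standard Spark setting spark.sql.parquet.outputTimestampType; a usage sketch, assuming a SparkSession named spark:

    // Recommended: write timestamps as INT64 annotated with TIMESTAMP_MICROS
    spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")

    // Default, kept for backward compatibility: INT96
    // spark.conf.set("spark.sql.parquet.outputTimestampType", "INT96")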
+ case TimestampType => + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + Types.primitive(INT96, repetition).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MICROS).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + } + + case BinaryType => + Types.primitive(BINARY, repetition).named(field.name) + + // ====================== + // Decimals (legacy mode) + // ====================== + + // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and + // always store decimals in fixed-length byte arrays. To keep compatibility with these older + // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated + // by `DECIMAL`. + case DecimalType.Fixed(precision, scale) if writeLegacyParquetFormat => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(Decimal.minBytesForPrecision(precision)) + .named(field.name) + + // ======================== + // Decimals (standard mode) + // ======================== + + // Uses INT32 for 1 <= precision <= 9 + case DecimalType.Fixed(precision, scale) + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => + Types + .primitive(INT32, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses INT64 for 1 <= precision <= 18 + case DecimalType.Fixed(precision, scale) + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => + Types + .primitive(INT64, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses FIXED_LEN_BYTE_ARRAY for all other precisions + case DecimalType.Fixed(precision, scale) if !writeLegacyParquetFormat => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(Decimal.minBytesForPrecision(precision)) + .named(field.name) + + // =================================== + // ArrayType and MapType (legacy mode) + // =================================== + + // Spark 1.4.x and prior versions convert `ArrayType` with nullable elements into a 3-level + // `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro + // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element + // field name "array" is borrowed from parquet-avro. + case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat => + // group (LIST) { + // optional group bag { + // repeated array; + // } + // } + + // This should not use `listOfElements` here because this new method checks if the + // element name is `element` in the `GroupType` and throws an exception if not. + // As mentioned above, Spark prior to 1.4.x writes `ArrayType` as `LIST` but with + // `array` as its element name as below. Therefore, we build manually + // the correct group type here via the builder. (See SPARK-16777) + Types + .buildGroup(repetition) + .as(LIST) + .addField( + Types + .buildGroup(REPEATED) + // "array" is the name chosen by parquet-hive (1.7.0 and prior version) + .addField(convertField(StructField("array", elementType, nullable))) + .named("bag")) + .named(field.name) + + // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level + // LIST structure. 
This behavior mimics parquet-avro (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat => + // group (LIST) { + // repeated element; + // } + + // Here too, we should not use `listOfElements`. (See SPARK-16777) + Types + .buildGroup(repetition) + .as(LIST) + // "array" is the name chosen by parquet-avro (1.7.0 and prior version) + .addField(convertField(StructField("array", elementType, nullable), REPEATED)) + .named(field.name) + + // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by + // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. + case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat => + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // required key; + // value; + // } + // } + ConversionPatterns.mapType( + repetition, + field.name, + convertField(StructField("key", keyType, nullable = false)), + convertField(StructField("value", valueType, valueContainsNull))) + + // ===================================== + // ArrayType and MapType (standard mode) + // ===================================== + + case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat => + // group (LIST) { + // repeated group list { + // element; + // } + // } + Types + .buildGroup(repetition) + .as(LIST) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("element", elementType, containsNull))) + .named("list")) + .named(field.name) + + case MapType(keyType, valueType, valueContainsNull) => + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + Types + .buildGroup(repetition) + .as(MAP) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("key", keyType, nullable = false))) + .addField(convertField(StructField("value", valueType, valueContainsNull))) + .named("key_value")) + .named(field.name) + + // =========== + // Other types + // =========== + + case StructType(fields) => + fields + .foldLeft(Types.buildGroup(repetition)) { (builder, field) => + builder.addField(convertField(field)) + } + .named(field.name) + + case udt: UserDefinedType[_] => + convertField(field.copy(dataType = udt.sqlType)) + + case _ => + throw new IllegalArgumentException( + s"Unsupported data type ${field.dataType.catalogString}") + } + } +} + +private[sql] object GeoParquetSchemaConverter { + def checkFieldName(name: String): Unit = { + // ,;{}()\n\t= and space are special characters in Parquet schema + checkConversionRequirement( + !name.matches(".*[ ,;{}()\n\t=].*"), + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ").trim) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala new file mode 100644 index 0000000000..477d744441 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + +import scala.language.existentials + +object GeoParquetUtils { + def inferSchema( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val parquetOptions = new ParquetOptions(parameters, sparkSession.sessionState.conf) + val shouldMergeSchemas = parquetOptions.mergeSchema + val mergeRespectSummaries = sparkSession.sessionState.conf.isParquetSchemaRespectSummaries + val filesByType = splitFiles(files) + val filesToTouch = + if (shouldMergeSchemas) { + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq.empty + } else { + filesByType.data + } + needMerged ++ filesByType.metadata ++ filesByType.commonMetadata + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + filesByType.commonMetadata.headOption + // Falls back to "_metadata" + .orElse(filesByType.metadata.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(filesByType.data.headOption) + .toSeq + } + GeoParquetFileFormat.mergeSchemasInParallel(parameters, filesToTouch, sparkSession) + } + + case class FileTypes( + data: Seq[FileStatus], + metadata: Seq[FileStatus], + commonMetadata: Seq[FileStatus]) + + private def splitFiles(allFiles: Seq[FileStatus]): FileTypes = { + val leaves = allFiles.toArray.sortBy(_.getPath.toString) + + FileTypes( + data = leaves.filterNot(f => isSummaryFile(f.getPath)), + metadata = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE), + commonMetadata = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)) + } + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + /** + * Legacy mode option is for reading Parquet files written by old versions of Apache Sedona (<= + * 1.3.1-incubating). Such files are actually not GeoParquet files and do not have GeoParquet + * file metadata. Geometry fields were encoded as list of bytes and stored as group type in + * Parquet files. The Definition of GeometryUDT before 1.4.0 was: + * {{{ + * case class GeometryUDT extends UserDefinedType[Geometry] { + * override def sqlType: DataType = ArrayType(ByteType, containsNull = false) + * // ... + * }}} + * Since 1.4.0, the sqlType of GeometryUDT is changed to BinaryType. 
This is a breaking change + * for reading old Parquet files. To read old Parquet files, users need to use "geoparquet" + * format and set legacyMode to true. + * @param parameters + * user provided parameters for reading GeoParquet files using `.option()` method, e.g. + * `spark.read.format("geoparquet").option("legacyMode", "true").load("path")` + * @return + * true if legacyMode is set to true, false otherwise + */ + def isLegacyMode(parameters: Map[String, String]): Boolean = + parameters.getOrElse("legacyMode", "false").toBoolean + + /** + * Parse GeoParquet file metadata from Parquet file metadata. Legacy parquet files do not + * contain GeoParquet file metadata, so we'll simply return an empty GeoParquetMetaData object + * when legacy mode is enabled. + * @param keyValueMetaData + * Parquet file metadata + * @param parameters + * user provided parameters for reading GeoParquet files + * @return + * GeoParquetMetaData object + */ + def parseGeoParquetMetaData( + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): GeoParquetMetaData = { + val isLegacyMode = GeoParquetUtils.isLegacyMode(parameters) + GeoParquetMetaData.parseKeyValueMetaData(keyValueMetaData).getOrElse { + if (isLegacyMode) { + GeoParquetMetaData(None, "", Map.empty) + } else { + throw new IllegalArgumentException("GeoParquet file does not contain valid geo metadata") + } + } + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala new file mode 100644 index 0000000000..90d6d962f4 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala @@ -0,0 +1,628 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.FinalizedWriteContext +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.io.api.Binary +import org.apache.parquet.io.api.RecordConsumer +import org.apache.sedona.common.utils.GeomUtils +import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData.{GEOPARQUET_COVERING_KEY, GEOPARQUET_CRS_KEY, GEOPARQUET_VERSION_KEY, VERSION, createCoveringColumnMetadata} +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetWriteSupport.GeometryColumnInfo +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ +import org.json4s.{DefaultFormats, Extraction, JValue} +import org.json4s.jackson.JsonMethods.parse +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.io.WKBWriter + +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * A Parquet [[WriteSupport]] implementation that writes Catalyst [[InternalRow]]s as Parquet + * messages. This class can write Parquet data in two modes: + * + * - Standard mode: Parquet data are written in standard format defined in parquet-format spec. + * - Legacy mode: Parquet data are written in legacy format compatible with Spark 1.4 and prior. + * + * This behavior can be controlled by SQL option `spark.sql.parquet.writeLegacyFormat`. The value + * of this option is propagated to this class by the `init()` method and its Hadoop configuration + * argument. + */ +class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging { + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to the record consumer. + // Here we are using `SpecializedGetters` rather than `InternalRow` so that we can directly access + // data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // Schema of the `InternalRow`s to be written + private var schema: StructType = _ + + // `ValueWriter`s for all fields of the schema + private var rootFieldWriters: Array[ValueWriter] = _ + + // The Parquet `RecordConsumer` to which all `InternalRow`s are written + private var recordConsumer: RecordConsumer = _ + + // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions + private var writeLegacyParquetFormat: Boolean = _ + + // Which parquet timestamp type to use when writing. 
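As a usage sketch of the two write modes described in the class comment above (assuming a SparkSession named spark, a DataFrame df containing a geometry column, and hypothetical output paths; the "geoparquet" format name matches the read-side examples elsewhere in this patch):

    // Standard mode (default): lists, maps and decimals follow the parquet-format spec
    spark.conf.set("spark.sql.parquet.writeLegacyFormat", "false")
    df.write.format("geoparquet").save("/tmp/geoparquet-standard")

    // Legacy mode: layout compatible with Spark 1.4 and prior versions
    spark.conf.set("spark.sql.parquet.writeLegacyFormat", "true")
    df.write.format("geoparquet").save("/tmp/geoparquet-legacy")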
+ private var outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = _ + + // Reusable byte array used to write timestamps as Parquet INT96 values + private val timestampBuffer = new Array[Byte](12) + + // Reusable byte array used to write decimal values + private val decimalBuffer = + new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) + + private val datetimeRebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_WRITE)) + + private val dateRebaseFunc = + GeoDataSourceUtils.creteDateRebaseFuncInWrite(datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInWrite(datetimeRebaseMode, "Parquet") + + private val int96RebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_WRITE)) + + private val int96RebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInWrite(int96RebaseMode, "Parquet INT96") + + // A mapping from geometry field ordinal to bounding box. According to the geoparquet specification, + // "Geometry columns MUST be at the root of the schema", so we don't need to worry about geometry + // fields in nested structures. + private val geometryColumnInfoMap: mutable.Map[Int, GeometryColumnInfo] = mutable.Map.empty + + private var geoParquetVersion: Option[String] = None + private var defaultGeoParquetCrs: Option[JValue] = None + private val geoParquetColumnCrsMap: mutable.Map[String, Option[JValue]] = mutable.Map.empty + private val geoParquetColumnCoveringMap: mutable.Map[String, Covering] = mutable.Map.empty + + override def init(configuration: Configuration): WriteContext = { + val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) + this.schema = StructType.fromString(schemaString) + this.writeLegacyParquetFormat = { + // `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set in ParquetRelation + assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != null) + configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean + } + + this.outputTimestampType = { + val key = SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key + assert(configuration.get(key) != null) + SQLConf.ParquetOutputTimestampType.withName(configuration.get(key)) + } + + this.rootFieldWriters = schema.zipWithIndex + .map { case (field, ordinal) => + makeWriter(field.dataType, Some(ordinal)) + } + .toArray[ValueWriter] + + if (geometryColumnInfoMap.isEmpty) { + throw new RuntimeException("No geometry column found in the schema") + } + + geoParquetVersion = configuration.get(GEOPARQUET_VERSION_KEY) match { + case null => Some(VERSION) + case version: String => Some(version) + } + defaultGeoParquetCrs = configuration.get(GEOPARQUET_CRS_KEY) match { + case null => + // If no CRS is specified, we write null to the crs metadata field. This is for compatibility with + // geopandas 0.10.0 and earlier versions, which requires crs field to be present. + Some(org.json4s.JNull) + case "" => None + case crs: String => Some(parse(crs)) + } + geometryColumnInfoMap.keys.map(schema(_).name).foreach { name => + Option(configuration.get(GEOPARQUET_CRS_KEY + "." 
+ name)).foreach { + case "" => geoParquetColumnCrsMap.put(name, None) + case crs: String => geoParquetColumnCrsMap.put(name, Some(parse(crs))) + } + } + Option(configuration.get(GEOPARQUET_COVERING_KEY)).foreach { coveringColumnName => + if (geometryColumnInfoMap.size > 1) { + throw new IllegalArgumentException( + s"$GEOPARQUET_COVERING_KEY is ambiguous when there are multiple geometry columns." + + s"Please specify $GEOPARQUET_COVERING_KEY. for configured geometry column.") + } + val geometryColumnName = schema(geometryColumnInfoMap.keys.head).name + val covering = createCoveringColumnMetadata(coveringColumnName, schema) + geoParquetColumnCoveringMap.put(geometryColumnName, covering) + } + geometryColumnInfoMap.keys.map(schema(_).name).foreach { name => + Option(configuration.get(GEOPARQUET_COVERING_KEY + "." + name)).foreach { + coveringColumnName => + val covering = createCoveringColumnMetadata(coveringColumnName, schema) + geoParquetColumnCoveringMap.put(name, covering) + } + } + + val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) + val sparkSqlParquetRowMetadata = GeoParquetWriteSupport.getSparkSqlParquetRowMetadata(schema) + val metadata = Map( + SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT, + ParquetReadSupport.SPARK_METADATA_KEY -> sparkSqlParquetRowMetadata) ++ { + if (datetimeRebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some("org.apache.spark.legacyDateTime" -> "") + } else { + None + } + } ++ { + if (int96RebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some("org.apache.spark.legacyINT96" -> "") + } else { + None + } + } + + logInfo(s"""Initialized Parquet WriteSupport with Catalyst schema: + |${schema.prettyJson} + |and corresponding Parquet message type: + |$messageType + """.stripMargin) + + new WriteContext(messageType, metadata.asJava) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + + override def finalizeWrite(): WriteSupport.FinalizedWriteContext = { + val metadata = new util.HashMap[String, String]() + if (geometryColumnInfoMap.nonEmpty) { + val primaryColumnIndex = geometryColumnInfoMap.keys.head + val primaryColumn = schema.fields(primaryColumnIndex).name + val columns = geometryColumnInfoMap.map { case (ordinal, columnInfo) => + val columnName = schema.fields(ordinal).name + val geometryTypes = columnInfo.seenGeometryTypes.toSeq + val bbox = if (geometryTypes.nonEmpty) { + Seq( + columnInfo.bbox.minX, + columnInfo.bbox.minY, + columnInfo.bbox.maxX, + columnInfo.bbox.maxY) + } else Seq(0.0, 0.0, 0.0, 0.0) + val crs = geoParquetColumnCrsMap.getOrElse(columnName, defaultGeoParquetCrs) + val covering = geoParquetColumnCoveringMap.get(columnName) + columnName -> GeometryFieldMetaData("WKB", geometryTypes, bbox, crs, covering) + }.toMap + val geoParquetMetadata = GeoParquetMetaData(geoParquetVersion, primaryColumn, columns) + val geoParquetMetadataJson = GeoParquetMetaData.toJson(geoParquetMetadata) + metadata.put("geo", geoParquetMetadataJson) + } + new FinalizedWriteContext(metadata) + } + + override def write(row: InternalRow): Unit = { + consumeMessage { + writeFields(row, schema, rootFieldWriters) + } + } + + private def writeFields( + row: InternalRow, + schema: StructType, + fieldWriters: Array[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + if (!row.isNullAt(i)) { + consumeField(schema(i).name, i) { + fieldWriters(i).apply(row, i) + } + } + i += 1 + } + } + + private def makeWriter(dataType: DataType, rootOrdinal: Option[Int] = 
None): ValueWriter = { + dataType match { + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getShort(ordinal)) + + case DateType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(dateRebaseFunc(row.getInt(ordinal))) + + case IntegerType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addFloat(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addDouble(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary( + Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes)) + + case TimestampType => + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + (row: SpecializedGetters, ordinal: Int) => + val micros = int96RebaseFunc(row.getLong(ordinal)) + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(micros) + val buf = ByteBuffer.wrap(timestampBuffer) + buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) + recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + (row: SpecializedGetters, ordinal: Int) => + val micros = row.getLong(ordinal) + recordConsumer.addLong(timestampRebaseFunc(micros)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + (row: SpecializedGetters, ordinal: Int) => + val micros = row.getLong(ordinal) + val millis = GeoDateTimeUtils.microsToMillis(timestampRebaseFunc(micros)) + recordConsumer.addLong(millis) + } + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal))) + + case DecimalType.Fixed(precision, scale) => + makeDecimalWriter(precision, scale) + + case t: StructType => + val fieldWriters = t.map(_.dataType).map(makeWriter(_, None)).toArray[ValueWriter] + (row: SpecializedGetters, ordinal: Int) => + consumeGroup { + writeFields(row.getStruct(ordinal, t.length), t, fieldWriters) + } + + case t: ArrayType => makeArrayWriter(t) + + case t: MapType => makeMapWriter(t) + + case GeometryUDT => + val geometryColumnInfo = rootOrdinal match { + case Some(ordinal) => + geometryColumnInfoMap.getOrElseUpdate(ordinal, new GeometryColumnInfo()) + case None => null + } + (row: SpecializedGetters, ordinal: Int) => { + val serializedGeometry = row.getBinary(ordinal) + val geom = GeometryUDT.deserialize(serializedGeometry) + val wkbWriter = new WKBWriter(GeomUtils.getDimension(geom)) + recordConsumer.addBinary(Binary.fromReusedByteArray(wkbWriter.write(geom))) + if (geometryColumnInfo != null) { + geometryColumnInfo.update(geom) + } + } + + case t: UserDefinedType[_] => makeWriter(t.sqlType) + + // TODO Adds IntervalType support + case _ => sys.error(s"Unsupported data type $dataType.") + } + } + + private def makeDecimalWriter(precision: Int, scale: Int): ValueWriter = { + assert( + precision <= DecimalType.MAX_PRECISION, + s"Decimal precision $precision exceeds max precision 
${DecimalType.MAX_PRECISION}") + + val numBytes = Decimal.minBytesForPrecision(precision) + + val int32Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addInteger(unscaledLong.toInt) + } + + val int64Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addLong(unscaledLong) + } + + val binaryWriterUsingUnscaledLong = + (row: SpecializedGetters, ordinal: Int) => { + // When the precision is low enough (<= 18) to squeeze the decimal value into a `Long`, we + // can build a fixed-length byte array with length `numBytes` using the unscaled `Long` + // value and the `decimalBuffer` for better performance. + val unscaled = row.getDecimal(ordinal, precision, scale).toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + + while (i < numBytes) { + decimalBuffer(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + + recordConsumer.addBinary(Binary.fromReusedByteArray(decimalBuffer, 0, numBytes)) + } + + val binaryWriterUsingUnscaledBytes = + (row: SpecializedGetters, ordinal: Int) => { + val decimal = row.getDecimal(ordinal, precision, scale) + val bytes = decimal.toJavaBigDecimal.unscaledValue().toByteArray + val fixedLengthBytes = if (bytes.length == numBytes) { + // If the length of the underlying byte array of the unscaled `BigInteger` happens to be + // `numBytes`, just reuse it, so that we don't bother copying it to `decimalBuffer`. + bytes + } else { + // Otherwise, the length must be less than `numBytes`. In this case we copy contents of + // the underlying bytes with padding sign bytes to `decimalBuffer` to form the result + // fixed-length byte array. + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + + recordConsumer.addBinary(Binary.fromReusedByteArray(fixedLengthBytes, 0, numBytes)) + } + + writeLegacyParquetFormat match { + // Standard mode, 1 <= precision <= 9, writes as INT32 + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer + + // Standard mode, 10 <= precision <= 18, writes as INT64 + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer + + // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong + + // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY + case _ => binaryWriterUsingUnscaledBytes + } + } + + def makeArrayWriter(arrayType: ArrayType): ValueWriter = { + val elementWriter = makeWriter(arrayType.elementType) + + def threeLevelArrayWriter(repeatedGroupName: String, elementFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < array.numElements()) { + consumeGroup { + // Only creates the element field if the current array element is not null. 
+ if (!array.isNullAt(i)) { + consumeField(elementFieldName, 0) { + elementWriter.apply(array, i) + } + } + } + i += 1 + } + } + } + } + } + + def twoLevelArrayWriter(repeatedFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedFieldName, 0) { + var i = 0 + while (i < array.numElements()) { + elementWriter.apply(array, i) + i += 1 + } + } + } + } + } + + (writeLegacyParquetFormat, arrayType.containsNull) match { + case (legacyMode @ false, _) => + // Standard mode: + // + // group (LIST) { + // repeated group list { + // ^~~~ repeatedGroupName + // element; + // ^~~~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "list", elementFieldName = "element") + + case (legacyMode @ true, nullableElements @ true) => + // Legacy mode, with nullable elements: + // + // group (LIST) { + // optional group bag { + // ^~~ repeatedGroupName + // repeated array; + // ^~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "bag", elementFieldName = "array") + + case (legacyMode @ true, nullableElements @ false) => + // Legacy mode, with non-nullable elements: + // + // group (LIST) { + // repeated array; + // ^~~~~ repeatedFieldName + // } + twoLevelArrayWriter(repeatedFieldName = "array") + } + } + + private def makeMapWriter(mapType: MapType): ValueWriter = { + val keyWriter = makeWriter(mapType.keyType) + val valueWriter = makeWriter(mapType.valueType) + val repeatedGroupName = if (writeLegacyParquetFormat) { + // Legacy mode: + // + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // ^~~ repeatedGroupName + // required key; + // value; + // } + // } + "map" + } else { + // Standard mode: + // + // group (MAP) { + // repeated group key_value { + // ^~~~~~~~~ repeatedGroupName + // required key; + // value; + // } + // } + "key_value" + } + + (row: SpecializedGetters, ordinal: Int) => { + val map = row.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + + consumeGroup { + // Only creates the repeated field if the map is non-empty. + if (map.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < map.numElements()) { + consumeGroup { + consumeField("key", 0) { + keyWriter.apply(keyArray, i) + } + + // Only creates the "value" field if the value if non-empty + if (!map.valueArray().isNullAt(i)) { + consumeField("value", 1) { + valueWriter.apply(valueArray, i) + } + } + } + i += 1 + } + } + } + } + } + } + + private def consumeMessage(f: => Unit): Unit = { + recordConsumer.startMessage() + f + recordConsumer.endMessage() + } + + private def consumeGroup(f: => Unit): Unit = { + recordConsumer.startGroup() + f + recordConsumer.endGroup() + } + + private def consumeField(field: String, index: Int)(f: => Unit): Unit = { + recordConsumer.startField(field, index) + f + recordConsumer.endField(field, index) + } +} + +object GeoParquetWriteSupport { + class GeometryColumnInfo { + val bbox: GeometryColumnBoundingBox = new GeometryColumnBoundingBox() + + // GeoParquet column metadata has a `geometry_types` property, which contains a list of geometry types + // that are present in the column. + val seenGeometryTypes: mutable.Set[String] = mutable.Set.empty + + def update(geom: Geometry): Unit = { + bbox.update(geom) + // In case of 3D geometries, a " Z" suffix gets added (e.g. ["Point Z"]). 
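A minimal sketch, using plain JTS (already a dependency of this writer), of how the Z check right below separates 2D and 3D geometries:

    import org.locationtech.jts.geom.{Coordinate, GeometryFactory}

    val factory = new GeometryFactory()
    val point2d = factory.createPoint(new Coordinate(1.0, 2.0))      // getZ is NaN
    val point3d = factory.createPoint(new Coordinate(1.0, 2.0, 3.0)) // getZ is 3.0
    // point2d is recorded as "Point", point3d as "Point Z" in geometry_types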
+ val hasZ = { + val coordinate = geom.getCoordinate + if (coordinate != null) !coordinate.getZ.isNaN else false + } + val geometryType = if (!hasZ) geom.getGeometryType else geom.getGeometryType + " Z" + seenGeometryTypes.add(geometryType) + } + } + + class GeometryColumnBoundingBox( + var minX: Double = Double.PositiveInfinity, + var minY: Double = Double.PositiveInfinity, + var maxX: Double = Double.NegativeInfinity, + var maxY: Double = Double.NegativeInfinity) { + def update(geom: Geometry): Unit = { + val env = geom.getEnvelopeInternal + minX = math.min(minX, env.getMinX) + minY = math.min(minY, env.getMinY) + maxX = math.max(maxX, env.getMaxX) + maxY = math.max(maxY, env.getMaxY) + } + } + + private def getSparkSqlParquetRowMetadata(schema: StructType): String = { + val fields = schema.fields.map { field => + field.dataType match { + case _: GeometryUDT => + // Don't write the GeometryUDT type to the Parquet metadata. Write the type as binary for maximum + // compatibility. + field.copy(dataType = BinaryType) + case _ => field + } + } + StructType(fields).json + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala new file mode 100644 index 0000000000..aadca3a60f --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoSchemaMergeUtils { + + def mergeSchemasInParallel( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus], + schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType]) + : Option[StructType] = { + val serializedConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(parameters)) + + // !! HACK ALERT !! + // Here is a hack for Parquet, but it can be used by Orc as well. + // + // Parquet requires `FileStatus`es to read footers. + // Here we try to send cached `FileStatus`es to executor side to avoid fetching them again. + // However, `FileStatus` is not `Serializable` + // but only `Writable`. 
What makes it worse, for some reason, `FileStatus` doesn't play well + // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These + // facts virtually prevents us to serialize `FileStatus`es. + // + // Since Parquet only relies on path and length information of those `FileStatus`es to read + // footers, here we just extract them (which can be easily serialized), send them to executor + // side, and resemble fake `FileStatus`es there. + val partialFileStatusInfo = files.map(f => (f.getPath.toString, f.getLen)) + + // Set the number of partitions to prevent following schema reads from generating many tasks + // in case of a small number of orc files. + val numParallelism = Math.min( + Math.max(partialFileStatusInfo.size, 1), + sparkSession.sparkContext.defaultParallelism) + + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + + // Issues a Spark job to read Parquet/ORC schema in parallel. + val partiallyMergedSchemas = + sparkSession.sparkContext + .parallelize(partialFileStatusInfo, numParallelism) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + val schemas = schemaReader(fakeFileStatuses, serializedConf.value, ignoreCorruptFiles) + + if (schemas.isEmpty) { + Iterator.empty + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergedSchema.merge(schema) + } catch { + case cause: SparkException => + throw new SparkException(s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } + } + .collect() + + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { + case cause: SparkException => + throw new SparkException(s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala new file mode 100644 index 0000000000..43e1ababb7 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Data source for reading GeoParquet metadata. This could be accessed using the `spark.read` + * interface: + * {{{ + * val df = spark.read.format("geoparquet.metadata").load("path/to/geoparquet") + * }}} + */ +class GeoParquetMetadataDataSource extends FileDataSourceV2 with DataSourceRegister { + override val shortName: String = "geoparquet.metadata" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + None, + fallbackFileFormat) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala new file mode 100644 index 0000000000..1fe2faa2e0 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods.{compact, render} + +case class GeoParquetMetadataPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + filters: Seq[Filter]) + extends FilePartitionReaderFactory { + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { + val iter = GeoParquetMetadataPartitionReaderFactory.readFile( + broadcastedConf.value.value, + partitionedFile, + readDataSchema) + val fileReader = new PartitionReaderFromIterator[InternalRow](iter) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFile.partitionValues) + } +} + +object GeoParquetMetadataPartitionReaderFactory { + private def readFile( + configuration: Configuration, + partitionedFile: PartitionedFile, + readDataSchema: StructType): Iterator[InternalRow] = { + val filePath = partitionedFile.filePath + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath), configuration)) + .getFooter + .getFileMetaData + .getKeyValueMetaData + val row = GeoParquetMetaData.parseKeyValueMetaData(metadata) match { + case Some(geo) => + val geoColumnsMap = geo.columns.map { case (columnName, columnMetadata) => + implicit val formats: org.json4s.Formats = DefaultFormats + import org.json4s.jackson.Serialization + val columnMetadataFields: Array[Any] = Array( + UTF8String.fromString(columnMetadata.encoding), + new GenericArrayData(columnMetadata.geometryTypes.map(UTF8String.fromString).toArray), + new GenericArrayData(columnMetadata.bbox.toArray), + columnMetadata.crs + .map(projjson => UTF8String.fromString(compact(render(projjson)))) + .getOrElse(UTF8String.fromString("")), + columnMetadata.covering + .map(covering => UTF8String.fromString(Serialization.write(covering))) + .orNull) + val columnMetadataStruct = new GenericInternalRow(columnMetadataFields) + UTF8String.fromString(columnName) -> columnMetadataStruct + } + val fields: Array[Any] = Array( + UTF8String.fromString(filePath), + UTF8String.fromString(geo.version.orNull), + UTF8String.fromString(geo.primaryColumn), + ArrayBasedMapData(geoColumnsMap)) + new GenericInternalRow(fields) + case None => + // Not a GeoParquet file, return a row with null metadata values. 
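+        // Descriptive note (not in the original patch): only the file path is populated here;
+        // version, primary_column and columns are left null so that callers can tell plain
+        // Parquet files apart from GeoParquet files.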
+ val fields: Array[Any] = Array(UTF8String.fromString(filePath), null, null, null) + new GenericInternalRow(fields) + } + Iterator(pruneBySchema(row, GeoParquetMetadataTable.schema, readDataSchema)) + } + + private def pruneBySchema( + row: InternalRow, + schema: StructType, + readDataSchema: StructType): InternalRow = { + // Projection push down for nested fields is not enabled, so this very simple implementation is enough. + val values: Array[Any] = readDataSchema.fields.map { field => + val index = schema.fieldIndex(field.name) + row.get(index, field.dataType) + } + new GenericInternalRow(values) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala new file mode 100644 index 0000000000..b86ab7a399 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ + +case class GeoParquetMetadataScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. + // We should use `readPartitionSchema` as the partition schema here. 
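+    // Descriptive note (not in the original patch): pushedFilters are forwarded to the reader
+    // factory below, but the metadata reader itself does not use them to prune input files.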
+ GeoParquetMetadataPartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + pushedFilters) + } + + override def getFileUnSplittableReason(path: Path): String = + "Reading parquet file metadata does not require splitting the file" + + // This is for compatibility with Spark 3.0. Spark 3.3 does not have this method + def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = { + copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala new file mode 100644 index 0000000000..6a25e4530c --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.read.Scan
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class GeoParquetMetadataScanBuilder(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    schema: StructType,
+    dataSchema: StructType,
+    options: CaseInsensitiveStringMap)
+    extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
+  override def build(): Scan = {
+    GeoParquetMetadataScan(
+      sparkSession,
+      fileIndex,
+      dataSchema,
+      readDataSchema(),
+      readPartitionSchema(),
+      options,
+      getPushedDataFilters,
+      getPartitionFilters,
+      getDataFilters)
+  }
+
+  // The following methods use reflection to address compatibility issues for Spark 3.0 ~ 3.2
+
+  private def getPushedDataFilters: Array[Filter] = {
+    try {
+      val field = classOf[FileScanBuilder].getDeclaredField("pushedDataFilters")
+      field.setAccessible(true)
+      field.get(this).asInstanceOf[Array[Filter]]
+    } catch {
+      case _: NoSuchFieldException =>
+        Array.empty
+    }
+  }
+
+  private def getPartitionFilters: Seq[Expression] = {
+    try {
+      val field = classOf[FileScanBuilder].getDeclaredField("partitionFilters")
+      field.setAccessible(true)
+      field.get(this).asInstanceOf[Seq[Expression]]
+    } catch {
+      case _: NoSuchFieldException =>
+        Seq.empty
+    }
+  }
+
+  private def getDataFilters: Seq[Expression] = {
+    try {
+      val field = classOf[FileScanBuilder].getDeclaredField("dataFilters")
+      field.setAccessible(true)
+      field.get(this).asInstanceOf[Seq[Expression]]
+    } catch {
+      case _: NoSuchFieldException =>
+        Seq.empty
+    }
+  }
+}
diff --git a/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala
new file mode 100644
index 0000000000..845764fae5
--- /dev/null
+++ b/spark/spark-3.1/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class GeoParquetMetadataTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + override def formatName: String = "GeoParquet Metadata" + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + Some(GeoParquetMetadataTable.schema) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new GeoParquetMetadataScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) +} + +object GeoParquetMetadataTable { + private val columnMetadataType = StructType( + Seq( + StructField("encoding", StringType, nullable = true), + StructField("geometry_types", ArrayType(StringType), nullable = true), + StructField("bbox", ArrayType(DoubleType), nullable = true), + StructField("crs", StringType, nullable = true), + StructField("covering", StringType, nullable = true))) + + private val columnsType = MapType(StringType, columnMetadataType, valueContainsNull = false) + + val schema: StructType = StructType( + Seq( + StructField("path", StringType, nullable = false), + StructField("version", StringType, nullable = true), + StructField("primary_column", StringType, nullable = true), + StructField("columns", columnsType, nullable = true))) +} diff --git a/spark/spark-3.1/src/test/resources/log4j2.properties b/spark/spark-3.1/src/test/resources/log4j2.properties new file mode 100644 index 0000000000..5f89859463 --- /dev/null +++ b/spark/spark-3.1/src/test/resources/log4j2.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file target/unit-tests.log +rootLogger.level = info +rootLogger.appenderRef.file.ref = File + +appender.file.type = File +appender.file.name = File +appender.file.fileName = target/unit-tests.log +appender.file.append = true +appender.file.layout.type = PatternLayout +appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Ignore messages below warning level from Jetty, because it's a bit verbose +logger.jetty.name = org.sparkproject.jetty +logger.jetty.level = warn diff --git a/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala new file mode 100644 index 0000000000..421890c700 --- /dev/null +++ b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.spark.sql.Row +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.scalatest.BeforeAndAfterAll + +import java.util.Collections +import scala.collection.JavaConverters._ + +class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation: String = resourceFolder + "geoparquet/" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + describe("GeoParquet Metadata tests") { + it("Reading GeoParquet Metadata") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.collect() + assert(metadataArray.length > 1) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("version") == "1.0.0-dev")) + assert(metadataArray.exists(_.getAs[String]("primary_column") == "geometry")) + assert(metadataArray.exists { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + columnsMap != null && columnsMap + .containsKey("geometry") && columnsMap.get("geometry").isInstanceOf[Row] + }) + assert(metadataArray.forall { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + if (columnsMap == null || !columnsMap.containsKey("geometry")) true + else { + val columnMetadata = columnsMap.get("geometry").asInstanceOf[Row] + columnMetadata.getAs[String]("encoding") == "WKB" && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("bbox")) + .asScala + .forall(_.isInstanceOf[Double]) && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("geometry_types")) + .asScala + .forall(_.isInstanceOf[String]) && + columnMetadata.getAs[String]("crs").nonEmpty && + 
columnMetadata.getAs[String]("crs") != "null" + } + }) + } + + it("Reading GeoParquet Metadata with column pruning") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df + .selectExpr("path", "substring(primary_column, 1, 2) AS partial_primary_column") + .collect() + assert(metadataArray.length > 1) + assert(metadataArray.forall(_.length == 2)) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("partial_primary_column") == "ge")) + } + + it("Reading GeoParquet Metadata of plain parquet files") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.where("path LIKE '%plain.parquet'").collect() + assert(metadataArray.nonEmpty) + assert(metadataArray.forall(_.getAs[String]("path").endsWith("plain.parquet"))) + assert(metadataArray.forall(_.getAs[String]("version") == null)) + assert(metadataArray.forall(_.getAs[String]("primary_column") == null)) + assert(metadataArray.forall(_.getAs[String]("columns") == null)) + } + + it("Read GeoParquet without CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "") + } + + it("Read GeoParquet with null CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "null") + } + + it("Read GeoParquet with snake_case geometry column name and camelCase column name") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column_1", GeometryUDT, nullable = false), + StructField("geomColumn2", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")) + assert(metadata.containsKey("geom_column_1")) + assert(!metadata.containsKey("geoColumn1")) + assert(metadata.containsKey("geomColumn2")) + assert(!metadata.containsKey("geom_column2")) + assert(!metadata.containsKey("geom_column_2")) + } + + it("Read GeoParquet with covering metadata") { + val dfMeta = sparkSession.read + .format("geoparquet.metadata") + .load(geoparquetdatalocation + "/example-1.1.0.parquet") + val row = dfMeta.collect()(0) + val metadata = 
row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + val covering = metadata.getAs[String]("covering") + assert(covering.nonEmpty) + Seq("bbox", "xmin", "ymin", "xmax", "ymax").foreach { key => + assert(covering contains key) + } + } + } +} diff --git a/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala new file mode 100644 index 0000000000..8f3cc3f1e5 --- /dev/null +++ b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.generateTestData +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.readGeoParquetMetaDataMap +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.writeTestDataAsGeoParquet +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter +import org.locationtech.jts.geom.Coordinate +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.geom.GeometryFactory +import org.scalatest.prop.TableDrivenPropertyChecks + +import java.io.File +import java.nio.file.Files + +class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrivenPropertyChecks { + + val tempDir: String = + Files.createTempDirectory("sedona_geoparquet_test_").toFile.getAbsolutePath + val geoParquetDir: String = tempDir + "/geoparquet" + var df: DataFrame = _ + var geoParquetDf: DataFrame = _ + var geoParquetMetaDataMap: Map[Int, Seq[GeoParquetMetaData]] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + df = generateTestData(sparkSession) + writeTestDataAsGeoParquet(df, geoParquetDir) + geoParquetDf = sparkSession.read.format("geoparquet").load(geoParquetDir) + geoParquetMetaDataMap = readGeoParquetMetaDataMap(geoParquetDir) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir)) + + describe("GeoParquet spatial filter push down tests") { + it("Push down ST_Contains") { + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", + Seq.empty) + testFilter("ST_Contains(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + it("Push down ST_Covers") { + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", + Seq.empty) + testFilter("ST_Covers(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Covers(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Covers(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + it("Push down ST_Within") { + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", + Seq(1)) + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", + Seq(0)) + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_Within(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3)) + testFilter( + "ST_Within(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", + Seq(3)) + testFilter( + "ST_Within(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", + Seq.empty) + } + + it("Push down ST_CoveredBy") { + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", + Seq(1)) + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", + Seq(0)) + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_CoveredBy(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3)) + testFilter( + "ST_CoveredBy(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", + Seq(3)) + testFilter( + "ST_CoveredBy(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", + Seq.empty) + } + + it("Push down ST_Intersects") { + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_Intersects(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", + Seq(1, 3)) + } + + it("Push down ST_Equals") { + testFilter( + "ST_Equals(geom, ST_GeomFromText('POLYGON ((-16 -16, -16 -14, -14 -14, -14 -16, -16 -16))'))", + Seq(2)) + testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-15 -15)'))", Seq(2)) + 
testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-16 -16)'))", Seq(2)) + testFilter( + "ST_Equals(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + forAll(Table("<", "<=")) { op => + it(s"Push down ST_Distance $op d") { + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 1", Seq.empty) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 5", Seq.empty) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (3 4)')) $op 1", Seq(1)) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 7.1", Seq(0, 1, 2, 3)) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) $op 1", Seq(2)) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 2", + Seq.empty) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 3", + Seq(0, 1, 2, 3)) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('LINESTRING (17 17, 18 18)')) $op 1", + Seq(1)) + } + } + + it("Push down And(filters...)") { + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + } + + it("Push down Or(filters...)") { + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom) OR ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0, 1)) + testFilter( + "ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) <= 1 OR ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1, 2)) + } + + it("Ignore negated spatial filters") { + testFilter( + "NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(0, 1, 2, 3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) AND NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) OR NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(0, 1, 2, 3)) + } + + it("Mixed spatial filter with other filter") { + testFilter( + "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", + Seq(1, 3)) + } + } + + /** + * Test filter push down using specified query condition, and verify if the pushed down filter + * prunes regions as expected. We'll also verify the correctness of query results. 
+ * @param condition + * SQL query condition + * @param expectedPreservedRegions + * Regions that should be preserved after filter push down + */ + private def testFilter(condition: String, expectedPreservedRegions: Seq[Int]): Unit = { + val dfFiltered = geoParquetDf.where(condition) + val preservedRegions = getPushedDownSpatialFilter(dfFiltered) match { + case Some(spatialFilter) => resolvePreservedRegions(spatialFilter) + case None => (0 until 4) + } + assert(expectedPreservedRegions == preservedRegions) + val expectedResult = + df.where(condition).orderBy("region", "id").select("region", "id").collect() + val actualResult = dfFiltered.orderBy("region", "id").select("region", "id").collect() + assert(expectedResult sameElements actualResult) + } + + private def getPushedDownSpatialFilter(df: DataFrame): Option[GeoParquetSpatialFilter] = { + val executedPlan = df.queryExecution.executedPlan + val fileSourceScanExec = executedPlan.find(_.isInstanceOf[FileSourceScanExec]) + assert(fileSourceScanExec.isDefined) + val fileFormat = fileSourceScanExec.get.asInstanceOf[FileSourceScanExec].relation.fileFormat + assert(fileFormat.isInstanceOf[GeoParquetFileFormat]) + fileFormat.asInstanceOf[GeoParquetFileFormat].spatialFilter + } + + private def resolvePreservedRegions(spatialFilter: GeoParquetSpatialFilter): Seq[Int] = { + geoParquetMetaDataMap + .filter { case (_, metaDataList) => + metaDataList.exists(metadata => spatialFilter.evaluate(metadata.columns)) + } + .keys + .toSeq + } +} + +object GeoParquetSpatialFilterPushDownSuite { + case class TestDataItem(id: Int, region: Int, geom: Geometry) + + /** + * Generate test data centered at (0, 0). The entire dataset was divided into 4 quadrants, each + * with a unique region ID. The dataset contains 4 points and 4 polygons in each quadrant. + * @param sparkSession + * SparkSession object + * @return + * DataFrame containing test data + */ + def generateTestData(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + val regionCenters = Seq((-10, 10), (10, 10), (-10, -10), (10, -10)) + val testData = regionCenters.zipWithIndex.flatMap { case ((x, y), i) => + generateTestDataForRegion(i, x, y) + } + testData.toDF() + } + + private def generateTestDataForRegion(region: Int, centerX: Double, centerY: Double) = { + val factory = new GeometryFactory() + val points = Seq( + factory.createPoint(new Coordinate(centerX - 5, centerY + 5)), + factory.createPoint(new Coordinate(centerX + 5, centerY + 5)), + factory.createPoint(new Coordinate(centerX - 5, centerY - 5)), + factory.createPoint(new Coordinate(centerX + 5, centerY - 5))) + val polygons = points.map { p => + val envelope = p.getEnvelopeInternal + envelope.expandBy(1) + factory.toGeometry(envelope) + } + (points ++ polygons).zipWithIndex.map { case (g, i) => TestDataItem(i, region, g) } + } + + /** + * Write the test dataframe as GeoParquet files. Each region is written to a separate file. + * We'll test spatial filter push down by examining which regions were preserved/pruned by + * evaluating the pushed down spatial filters + * @param testData + * dataframe containing test data + * @param path + * path to write GeoParquet files + */ + def writeTestDataAsGeoParquet(testData: DataFrame, path: String): Unit = { + testData.coalesce(1).write.partitionBy("region").format("geoparquet").save(path) + } + + /** + * Load GeoParquet metadata for each region. Note that there could be multiple files for each + * region, thus each region ID was associated with a list of GeoParquet metadata. 
+ * @param path + * path to directory containing GeoParquet files + * @return + * Map of region ID to list of GeoParquet metadata + */ + def readGeoParquetMetaDataMap(path: String): Map[Int, Seq[GeoParquetMetaData]] = { + (0 until 4).map { k => + val geoParquetMetaDataSeq = readGeoParquetMetaDataByRegion(path, k) + k -> geoParquetMetaDataSeq + }.toMap + } + + private def readGeoParquetMetaDataByRegion( + geoParquetSavePath: String, + region: Int): Seq[GeoParquetMetaData] = { + val parquetFiles = new File(geoParquetSavePath + s"/region=$region") + .listFiles() + .filter(_.getName.endsWith(".parquet")) + parquetFiles.flatMap { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + GeoParquetMetaData.parseKeyValueMetaData(metadata) + } + } +} diff --git a/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala new file mode 100644 index 0000000000..2da12eceb0 --- /dev/null +++ b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.log4j.{Level, Logger} +import org.apache.sedona.spark.SedonaContext +import org.apache.spark.sql.DataFrame +import org.scalatest.{BeforeAndAfterAll, FunSpec} + +trait TestBaseScala extends FunSpec with BeforeAndAfterAll { + Logger.getRootLogger().setLevel(Level.WARN) + Logger.getLogger("org.apache").setLevel(Level.WARN) + Logger.getLogger("com").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) + + val warehouseLocation = System.getProperty("user.dir") + "/target/" + val sparkSession = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + // We need to be explicit about broadcasting in tests. 
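+    // Descriptive note (not in the original patch): mirroring the semantics of
+    // spark.sql.autoBroadcastJoinThreshold, a value of -1 is assumed to disable automatic
+    // broadcast joins, so tests must request broadcasting explicitly when they need it.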
+ .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .getOrCreate() + + val resourceFolder = System.getProperty("user.dir") + "/../common/src/test/resources/" + + override def beforeAll(): Unit = { + SedonaContext.create(sparkSession) + } + + override def afterAll(): Unit = { + // SedonaSQLRegistrator.dropAll(spark) + // spark.stop + } + + def loadCsv(path: String): DataFrame = { + sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path) + } +} diff --git a/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala new file mode 100644 index 0000000000..ccfd560c84 --- /dev/null +++ b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala @@ -0,0 +1,748 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.Row +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.execution.datasources.parquet.{Covering, GeoParquetMetaData, ParquetReadSupport} +import org.apache.spark.sql.functions.{col, expr} +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.sedona_sql.expressions.st_constructors.{ST_Point, ST_PolygonFromEnvelope} +import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.json4s.jackson.parseJson +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.io.WKTReader +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.util.Collections +import java.util.concurrent.atomic.AtomicLong +import scala.collection.JavaConverters._ + +class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation1: String = resourceFolder + "geoparquet/example1.parquet" + val geoparquetdatalocation2: String = resourceFolder + "geoparquet/example2.parquet" + val geoparquetdatalocation3: String = resourceFolder + "geoparquet/example3.parquet" + val geoparquetdatalocation4: String = resourceFolder + "geoparquet/example-1.0.0-beta.1.parquet" + val geoparquetdatalocation5: String = resourceFolder + "geoparquet/example-1.1.0.parquet" + val legacyparquetdatalocation: String = + resourceFolder + 
"parquet/legacy-parquet-nested-columns.snappy.parquet" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(geoparquetoutputlocation)) + + describe("GeoParquet IO tests") { + it("GEOPARQUET Test example1 i.e. naturalearth_lowers dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1) + val rows = df.collect()(0) + assert(rows.getAs[Long]("pop_est") == 920938) + assert(rows.getAs[String]("continent") == "Oceania") + assert(rows.getAs[String]("name") == "Fiji") + assert(rows.getAs[String]("iso_a3") == "FJI") + assert(rows.getAs[Double]("gdp_md_est") == 8374.0) + assert( + rows + .getAs[Geometry]("geometry") + .toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177.28504 -17.72465, 177.67087 -17.381140000000002, 178.12557 -17.50481)), ((-179.79332010904864 -16.020882256741224, -179.9173693847653 -16.501783135649397, -180 -16.555216566639196, -180 -16.067132663642447, -179.79332010904864 -16.020882256741224)))") + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample1.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample1.parquet") + val newrows = df2.collect()(0) + assert( + newrows + .getAs[Geometry]("geometry") + .toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177.28504 -17.72465, 177.67087 -17.381140000000002, 178.12557 -17.50481)), ((-179.79332010904864 -16.020882256741224, -179.9173693847653 -16.501783135649397, -180 -16.555216566639196, -180 -16.067132663642447, -179.79332010904864 -16.020882256741224)))") + } + it("GEOPARQUET Test example2 i.e. naturalearth_citie dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation2) + val rows = df.collect()(0) + assert(rows.getAs[String]("name") == "Vatican City") + assert( + rows + .getAs[Geometry]("geometry") + .toString == "POINT (12.453386544971766 41.903282179960115)") + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample2.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample2.parquet") + val newrows = df2.collect()(0) + assert(newrows.getAs[String]("name") == "Vatican City") + assert( + newrows + .getAs[Geometry]("geometry") + .toString == "POINT (12.453386544971766 41.903282179960115)") + } + it("GEOPARQUET Test example3 i.e. 
nybb dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation3) + val rows = df.collect()(0) + assert(rows.getAs[Long]("BoroCode") == 5) + assert(rows.getAs[String]("BoroName") == "Staten Island") + assert(rows.getAs[Double]("Shape_Leng") == 330470.010332) + assert(rows.getAs[Double]("Shape_Area") == 1.62381982381e9) + assert(rows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022")) + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample3.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample3.parquet") + val newrows = df2.collect()(0) + assert( + newrows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022")) + } + it("GEOPARQUET Test example-1.0.0-beta.1.parquet") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation4) + val count = df.count() + val rows = df.collect() + assert(rows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(count == rows.length) + + val geoParquetSavePath = geoparquetoutputlocation + "/gp_sample4.parquet" + df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val newRows = df2.collect() + assert(rows.length == newRows.length) + assert(newRows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(rows sameElements newRows) + + val parquetFiles = + new File(geoParquetSavePath).listFiles().filter(_.getName.endsWith(".parquet")) + parquetFiles.foreach { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + val geo = parseJson(metadata.get("geo")) + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + assert(columnName == "geometry") + val geomTypes = (geo \ "columns" \ "geometry" \ "geometry_types").extract[Seq[String]] + assert(geomTypes.nonEmpty) + val sparkSqlRowMetadata = metadata.get(ParquetReadSupport.SPARK_METADATA_KEY) + assert(!sparkSqlRowMetadata.contains("GeometryUDT")) + } + } + it("GEOPARQUET Test example-1.1.0.parquet") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation5) + val count = df.count() + val rows = df.collect() + assert(rows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(count == rows.length) + + val geoParquetSavePath = geoparquetoutputlocation + "/gp_sample5.parquet" + df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val newRows = df2.collect() + assert(rows.length == newRows.length) + assert(newRows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(rows sameElements newRows) + } + + it("GeoParquet with multiple geometry columns") { + val wktReader = new WKTReader() + val testData = Seq( + Row( + 1, + wktReader.read("POINT (1 2)"), + wktReader.read("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))")), + Row( + 2, + wktReader.read("POINT Z(1 2 3)"), + wktReader.read("POLYGON Z((0 0 2, 1 0 2, 1 1 2, 0 1 2, 0 0 2))")), + Row( + 3, + wktReader.read("MULTIPOINT (0 0, 1 1, 2 2)"), + wktReader.read("MULTILINESTRING ((0 0, 1 1), (2 2, 3 3))"))) + val schema = StructType( + Seq( + StructField("id", 
IntegerType, nullable = false), + StructField("g0", GeometryUDT, nullable = false), + StructField("g1", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) + val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + // Find parquet files in geoParquetSavePath directory and validate their metadata + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val version = (geo \ "version").extract[String] + assert(version == GeoParquetMetaData.VERSION) + val g0Types = (geo \ "columns" \ "g0" \ "geometry_types").extract[Seq[String]] + val g1Types = (geo \ "columns" \ "g1" \ "geometry_types").extract[Seq[String]] + assert(g0Types.sorted == Seq("Point", "Point Z", "MultiPoint").sorted) + assert(g1Types.sorted == Seq("Polygon", "Polygon Z", "MultiLineString").sorted) + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNull) + assert(g1Crs == org.json4s.JNull) + } + + // Read GeoParquet with multiple geometry columns + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.schema.fields(1).dataType.isInstanceOf[GeometryUDT]) + assert(df2.schema.fields(2).dataType.isInstanceOf[GeometryUDT]) + val rows = df2.collect() + assert(testData.length == rows.length) + assert(rows(0).getAs[AnyRef]("g0").isInstanceOf[Geometry]) + assert(rows(0).getAs[AnyRef]("g1").isInstanceOf[Geometry]) + } + + it("GeoParquet save should work with empty dataframes") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("g", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/empty.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.schema.fields(1).dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val g0Types = (geo \ "columns" \ "g" \ "geometry_types").extract[Seq[String]] + val g0BBox = (geo \ "columns" \ "g" \ "bbox").extract[Seq[Double]] + assert(g0Types.isEmpty) + assert(g0BBox == Seq(0.0, 0.0, 0.0, 0.0)) + } + } + + it("GeoParquet save should work with snake_case column names") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/snake_case_column_name.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val geomField = df2.schema.fields(1) + assert(geomField.name == "geom_column") + assert(geomField.dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + } + + it("GeoParquet save should work with camelCase column names") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geomColumn", GeometryUDT, nullable = false))) + val df = 
sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/camel_case_column_name.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val geomField = df2.schema.fields(1) + assert(geomField.name == "geomColumn") + assert(geomField.dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + } + + it("GeoParquet save should write user specified version and crs to geo metadata") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation4) + // This CRS is taken from https://proj.org/en/9.3/specifications/projjson.html#geographiccrs + // with slight modification. + val projjson = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "NAD83(2011)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "NAD83 (National Spatial Reference System 2011)", + | "ellipsoid": { + | "name": "GRS 1980", + | "semi_major_axis": 6378137, + | "inverse_flattening": 298.257222101 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Horizontal component of 3D system.", + | "area": "Puerto Rico - onshore and offshore. United States (USA) onshore and offshore.", + | "bbox": { + | "south_latitude": 14.92, + | "west_longitude": 167.65, + | "north_latitude": 74.71, + | "east_longitude": -63.88 + | }, + | "id": { + | "authority": "EPSG", + | "code": 6318 + | } + |} + |""".stripMargin + var geoParquetSavePath = geoparquetoutputlocation + "/gp_custom_meta.parquet" + df.write + .format("geoparquet") + .option("geoparquet.version", "10.9.8") + .option("geoparquet.crs", projjson) + .mode("overwrite") + .save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.count() == df.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val version = (geo \ "version").extract[String] + val columnName = (geo \ "primary_column").extract[String] + assert(version == "10.9.8") + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs.isInstanceOf[org.json4s.JObject]) + assert(crs == parseJson(projjson)) + } + + // Setting crs to null explicitly + geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val df3 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df3.count() == df.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs == org.json4s.JNull) + } + + // Setting crs to "" to omit crs + geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: 
org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs == org.json4s.JNothing) + } + } + + it("GeoParquet save should support specifying per-column CRS") { + val wktReader = new WKTReader() + val testData = Seq( + Row( + 1, + wktReader.read("POINT (1 2)"), + wktReader.read("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"))) + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("g0", GeometryUDT, nullable = false), + StructField("g1", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) + + val projjson0 = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "NAD83(2011)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "NAD83 (National Spatial Reference System 2011)", + | "ellipsoid": { + | "name": "GRS 1980", + | "semi_major_axis": 6378137, + | "inverse_flattening": 298.257222101 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Horizontal component of 3D system.", + | "area": "Puerto Rico - onshore and offshore. United States (USA) onshore and offshore.", + | "bbox": { + | "south_latitude": 14.92, + | "west_longitude": 167.65, + | "north_latitude": 74.71, + | "east_longitude": -63.88 + | }, + | "id": { + | "authority": "EPSG", + | "code": 6318 + | } + |} + |""".stripMargin + + val projjson1 = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "Monte Mario (Rome)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "Monte Mario (Rome)", + | "ellipsoid": { + | "name": "International 1924", + | "semi_major_axis": 6378388, + | "inverse_flattening": 297 + | }, + | "prime_meridian": { + | "name": "Rome", + | "longitude": 12.4523333333333 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Geodesy, onshore minerals management.", + | "area": "Italy - onshore and offshore; San Marino, Vatican City State.", + | "bbox": { + | "south_latitude": 34.76, + | "west_longitude": 5.93, + | "north_latitude": 47.1, + | "east_longitude": 18.99 + | }, + | "id": { + | "authority": "EPSG", + | "code": 4806 + | } + |} + |""".stripMargin + + val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms_with_custom_crs.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == parseJson(projjson1)) + } + + // Write without fallback CRS for g0 + df.write + .format("geoparquet") + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") 
+ .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNull) + assert(g1Crs == parseJson(projjson1)) + } + + // Fallback CRS is omitting CRS + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNothing) + assert(g1Crs == parseJson(projjson1)) + } + + // Write with CRS, explicitly set CRS to null for g1 + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", "null") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == org.json4s.JNull) + } + + // Write with CRS, explicitly omit CRS for g1 + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", "") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == org.json4s.JNothing) + } + } + + it("GeoParquet load should raise exception when loading plain parquet files") { + val e = intercept[SparkException] { + sparkSession.read.format("geoparquet").load(resourceFolder + "geoparquet/plain.parquet") + } + assert(e.getMessage.contains("does not contain valid geo metadata")) + } + + it("GeoParquet load with spatial predicates") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1) + val rows = + df.where(ST_Intersects(ST_Point(35.174722, -6.552465), col("geometry"))).collect() + assert(rows.length == 1) + assert(rows(0).getAs[String]("name") == "Tanzania") + } + + it("Filter push down for nested columns") { + import sparkSession.implicits._ + + // Prepare multiple GeoParquet files with bbox metadata. There should be 10 files in total, each file contains + // 1000 records. + val dfIds = (0 until 10000).toDF("id") + val dfGeom = dfIds + .withColumn( + "bbox", + expr("struct(id as minx, id as miny, id + 1 as maxx, id + 1 as maxy)")) + .withColumn("geom", expr("ST_PolygonFromEnvelope(id, id, id + 1, id + 1)")) + .withColumn("part_id", expr("CAST(id / 1000 AS INTEGER)")) + .coalesce(1) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_bbox.parquet" + dfGeom.write + .partitionBy("part_id") + .format("geoparquet") + .mode("overwrite") + .save(geoParquetSavePath) + + val sparkListener = new SparkListener() { + val recordsRead = new AtomicLong(0) + + def reset(): Unit = recordsRead.set(0) + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead + this.recordsRead.getAndAdd(recordsRead) + } + } + + sparkSession.sparkContext.addSparkListener(sparkListener) + try { + val df = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + + // This should trigger filter push down to Parquet and only read one of the files. The number of records read + // should be less than 1000. 
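+        // Illustrative sketch (not part of the original test): the skipping asserted below relies on
+        // ordinary Parquet row-group min/max statistics collected for the nested bbox fields. One way to
+        // confirm the predicate actually reached the scan is to inspect the physical plan, e.g.
+        //   df.where("bbox.minx > 6000 and bbox.minx < 6600").explain()
+        // and look for the pushed filters on bbox.minx reported by the Parquet scan node.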
+ df.where("bbox.minx > 6000 and bbox.minx < 6600").count() + assert(sparkListener.recordsRead.get() <= 1000) + + // Reading these files using spatial filter. This should only read two of the files. + sparkListener.reset() + df.where(ST_Intersects(ST_PolygonFromEnvelope(7010, 7010, 8100, 8100), col("geom"))) + .count() + assert(sparkListener.recordsRead.get() <= 2000) + } finally { + sparkSession.sparkContext.removeSparkListener(sparkListener) + } + } + + it("Ready legacy parquet files written by Apache Sedona <= 1.3.1-incubating") { + val df = sparkSession.read + .format("geoparquet") + .option("legacyMode", "true") + .load(legacyparquetdatalocation) + val rows = df.collect() + assert(rows.nonEmpty) + rows.foreach { row => + assert(row.getAs[AnyRef]("geom").isInstanceOf[Geometry]) + assert(row.getAs[AnyRef]("struct_geom").isInstanceOf[Row]) + val structGeom = row.getAs[Row]("struct_geom") + assert(structGeom.getAs[AnyRef]("g0").isInstanceOf[Geometry]) + assert(structGeom.getAs[AnyRef]("g1").isInstanceOf[Geometry]) + } + } + + it("GeoParquet supports writing covering metadata") { + val df = sparkSession + .range(0, 100) + .toDF("id") + .withColumn("id", expr("CAST(id AS DOUBLE)")) + .withColumn("geometry", expr("ST_Point(id, id + 1)")) + .withColumn( + "test_cov", + expr("struct(id AS xmin, id + 1 AS ymin, id AS xmax, id + 1 AS ymax)")) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_covering_metadata.parquet" + df.write + .format("geoparquet") + .option("geoparquet.covering", "test_cov") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val coveringJsValue = geo \ "columns" \ "geometry" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov", "ymax")) + } + + df.write + .format("geoparquet") + .option("geoparquet.covering.geometry", "test_cov") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val coveringJsValue = geo \ "columns" \ "geometry" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov", "ymax")) + } + } + + it("GeoParquet supports writing covering metadata for multiple columns") { + val df = sparkSession + .range(0, 100) + .toDF("id") + .withColumn("id", expr("CAST(id AS DOUBLE)")) + .withColumn("geom1", expr("ST_Point(id, id + 1)")) + .withColumn( + "test_cov1", + expr("struct(id AS xmin, id + 1 AS ymin, id AS xmax, id + 1 AS ymax)")) + .withColumn("geom2", expr("ST_Point(10 * id, 10 * id + 1)")) + .withColumn( + "test_cov2", + expr( + "struct(10 * id AS xmin, 10 * id + 1 AS ymin, 10 * id AS xmax, 10 * id + 1 AS ymax)")) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_covering_metadata.parquet" + df.write + .format("geoparquet") + .option("geoparquet.covering.geom1", "test_cov1") + .option("geoparquet.covering.geom2", "test_cov2") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val 
formats: org.json4s.Formats = org.json4s.DefaultFormats + Seq(("geom1", "test_cov1"), ("geom2", "test_cov2")).foreach { + case (geomName, coveringName) => + val coveringJsValue = geo \ "columns" \ geomName \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq(coveringName, "xmin")) + assert(covering.bbox.ymin == Seq(coveringName, "ymin")) + assert(covering.bbox.xmax == Seq(coveringName, "xmax")) + assert(covering.bbox.ymax == Seq(coveringName, "ymax")) + } + } + + df.write + .format("geoparquet") + .option("geoparquet.covering.geom2", "test_cov2") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + assert(geo \ "columns" \ "geom1" \ "covering" == org.json4s.JNothing) + val coveringJsValue = geo \ "columns" \ "geom2" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov2", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov2", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov2", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov2", "ymax")) + } + } + } + + def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = { + val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet")) + parquetFiles.foreach { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + val geo = parseJson(metadata.get("geo")) + body(geo) + } + } +} diff --git a/spark/spark-3.2/.gitignore b/spark/spark-3.2/.gitignore new file mode 100644 index 0000000000..1cc6c4a1f6 --- /dev/null +++ b/spark/spark-3.2/.gitignore @@ -0,0 +1,12 @@ +/target/ +/.settings/ +/.classpath +/.project +/dependency-reduced-pom.xml +/doc/ +/.idea/ +*.iml +/latest/ +/spark-warehouse/ +/metastore_db/ +*.log diff --git a/spark/spark-3.2/pom.xml b/spark/spark-3.2/pom.xml new file mode 100644 index 0000000000..76c11238d0 --- /dev/null +++ b/spark/spark-3.2/pom.xml @@ -0,0 +1,145 @@ + + + + 4.0.0 + + org.apache.sedona + sedona-spark-parent-${spark.compat.version}_${scala.compat.version} + 1.6.1-SNAPSHOT + ../pom.xml + + sedona-spark-3.2_${scala.compat.version} + + ${project.groupId}:${project.artifactId} + A cluster computing system for processing large-scale spatial data: SQL API for Spark 3.2. 
+ http://sedona.apache.org/ + jar + + + false + + + + + org.apache.sedona + sedona-common + ${project.version} + + + com.fasterxml.jackson.core + * + + + + + org.apache.sedona + sedona-spark-common-${spark.compat.version}_${scala.compat.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.compat.version} + + + org.apache.spark + spark-sql_${scala.compat.version} + + + org.apache.hadoop + hadoop-client + + + org.apache.logging.log4j + log4j-1.2-api + + + org.geotools + gt-main + + + org.geotools + gt-referencing + + + org.geotools + gt-epsg-hsql + + + org.geotools + gt-geotiff + + + org.geotools + gt-coverage + + + org.geotools + gt-arcgrid + + + org.locationtech.jts + jts-core + + + org.wololo + jts2geojson + + + com.fasterxml.jackson.core + * + + + + + org.scala-lang + scala-library + + + org.scala-lang.modules + scala-collection-compat_${scala.compat.version} + + + org.scalatest + scalatest_${scala.compat.version} + + + org.mockito + mockito-inline + + + + src/main/scala + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + + diff --git a/spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000..e5f994e203 --- /dev/null +++ b/spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,2 @@ +org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala new file mode 100644 index 0000000000..4348325570 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.util.RebaseDateTime +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.util.Utils + +import scala.util.Try + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoDataSourceUtils { + + val PARQUET_REBASE_MODE_IN_READ = firstAvailableConf( + "spark.sql.parquet.datetimeRebaseModeInRead", + "spark.sql.legacy.parquet.datetimeRebaseModeInRead") + val PARQUET_REBASE_MODE_IN_WRITE = firstAvailableConf( + "spark.sql.parquet.datetimeRebaseModeInWrite", + "spark.sql.legacy.parquet.datetimeRebaseModeInWrite") + val PARQUET_INT96_REBASE_MODE_IN_READ = firstAvailableConf( + "spark.sql.parquet.int96RebaseModeInRead", + "spark.sql.legacy.parquet.int96RebaseModeInRead", + "spark.sql.legacy.parquet.datetimeRebaseModeInRead") + val PARQUET_INT96_REBASE_MODE_IN_WRITE = firstAvailableConf( + "spark.sql.parquet.int96RebaseModeInWrite", + "spark.sql.legacy.parquet.int96RebaseModeInWrite", + "spark.sql.legacy.parquet.datetimeRebaseModeInWrite") + + private def firstAvailableConf(confs: String*): String = { + confs.find(c => Try(SQLConf.get.getConfString(c)).isSuccess).get + } + + def datetimeRebaseMode( + lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) + .map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + } + .getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + + def int96RebaseMode( + lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) + .map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. 
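+        // Illustrative sketch (not part of the original code): the fallback mode comes from whichever of
+        // the configs resolved by firstAvailableConf above is set. For example, a user reading old files
+        // could force the legacy rebase explicitly, assuming a SparkSession named `spark`:
+        //   spark.conf.set("spark.sql.parquet.int96RebaseModeInRead", "LEGACY")
+        // Valid values are EXCEPTION, LEGACY and CORRECTED, mirroring LegacyBehaviorPolicy.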
+ if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + } + .getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + + def creteDateRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + days: Int => + if (days < RebaseDateTime.lastSwitchJulianDay) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteDateRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + days: Int => + if (days < RebaseDateTime.lastSwitchGregorianDay) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteTimestampRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + micros: Long => + if (micros < RebaseDateTime.lastSwitchJulianTs) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } + + def creteTimestampRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + micros: Long => + if (micros < RebaseDateTime.lastSwitchGregorianTs) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala new file mode 100644 index 0000000000..bf3c2a19a9 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoDateTimeUtils { + + /** + * Converts the timestamp to milliseconds since epoch. In Spark timestamp values have + * microseconds precision, so this conversion is lossy. + */ + def microsToMillis(micros: Long): Long = { + // When the timestamp is negative i.e before 1970, we need to adjust the milliseconds portion. + // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision. + // In millis precision the above needs to be represented as (-157700927877). + Math.floorDiv(micros, MICROS_PER_MILLIS) + } + + /** + * Converts milliseconds since the epoch to microseconds. + */ + def millisToMicros(millis: Long): Long = { + Math.multiplyExact(millis, MICROS_PER_MILLIS) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala new file mode 100644 index 0000000000..702c6f31fb --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.codec.CodecConfig +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.readParquetFootersInParallel +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +import java.net.URI +import scala.collection.JavaConverters._ +import scala.util.Failure +import scala.util.Try + +class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter]) + extends ParquetFileFormat + with GeoParquetFileFormatBase + with FileFormat + with DataSourceRegister + with Logging + with Serializable { + + def this() = this(None) + + override def equals(other: Any): Boolean = other.isInstanceOf[GeoParquetFileFormat] && + other.asInstanceOf[GeoParquetFileFormat].spatialFilter == spatialFilter + + override def hashCode(): Int = getClass.hashCode() + + def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat = + new GeoParquetFileFormat(Some(spatialFilter)) + + override def inferSchema( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + GeoParquetUtils.inferSchema(sparkSession, parameters, files) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) + + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[ParquetOutputCommitter], + classOf[OutputCommitter]) + + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo( + "Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo( + "Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + + conf.setClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, classOf[OutputCommitter]) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. 
The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // This metadata is useful for keeping UDTs like Vector/Matrix. + ParquetWriteSupport.setSchema(dataSchema, conf) + + // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet + // schema and writes actual rows to Parquet files. + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) + + conf.set( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + + try { + val fieldIdWriteEnabled = + SQLConf.get.getConfString("spark.sql.parquet.fieldId.write.enabled") + conf.set("spark.sql.parquet.fieldId.write.enabled", fieldIdWriteEnabled) + } catch { + case e: NoSuchElementException => () + } + + // Sets compression scheme + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) + + // SPARK-15719: Disables writing Parquet summary files by default. + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) + } + + if (ParquetOutputFormat.getJobSummaryLevel(conf) != JobSummaryLevel.NONE + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning( + s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. 
" + + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") + } + + conf.set(ParquetOutputFormat.WRITE_SUPPORT_CLASS, classOf[GeoParquetWriteSupport].getName) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" + } + } + } + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, requiredSchema.json) + hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, requiredSchema.json) + hadoopConf.set( + SQLConf.SESSION_LOCAL_TIMEZONE.key, + sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sparkSession.sessionState.conf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sparkSession.sessionState.conf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + // TODO: if you move this into the closure it reverts to the default values. + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). 
+ val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader: Boolean = + sqlConf.parquetVectorizedReaderEnabled && + resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + + (file: PartitionedFile) => { + assert(file.partitionValues.numFields == partitionSchema.size) + + val filePath = new Path(new URI(file.filePath)) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) + + val sharedConf = broadcastedHadoopConf.value.value + + val footerFileMetaData = + ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new GeoParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringStartWith, + pushDownInFilterThreshold, + isCaseSensitive) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter(_)) + .reduceOption(FilterApi.and) + } else { + None + } + + // Prune file scans using pushed down spatial filters and per-column bboxes in geoparquet metadata + val shouldScanFile = + GeoParquetMetaData.parseKeyValueMetaData(footerFileMetaData.getKeyValueMetaData).forall { + metadata => spatialFilter.forall(_.evaluate(metadata.columns)) + } + if (!shouldScanFile) { + // The entire file is pruned so that we don't need to scan this file. + Seq.empty[InternalRow].iterator + } else { + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' + // *only* if the file was created by something other than "parquet-mr", so check the actual + // writer here for this file. We have to do this per-file, as each file in the table may + // have different writers. + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. 
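+        // Illustrative note (not part of the original code): createdBy strings written by Spark via
+        // parquet-mr typically look like "parquet-mr version 1.12.x (build ...)", while files from other
+        // writers (e.g. parquet-cpp) carry a different prefix, which is what enables the INT96 timezone
+        // conversion below.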
+ def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy().startsWith("parquet-mr") + + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + val datetimeRebaseMode = GeoDataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_READ)) + val int96RebaseMode = GeoDataSourceUtils.int96RebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_READ)) + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = + new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId) + + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. + if (pushed.isDefined) { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + } + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + logWarning( + s"GeoParquet currently does not support vectorized reader. Falling back to parquet-mr") + } + logDebug(s"Falling back to parquet-mr") + // ParquetRecordReader returns InternalRow + val readSupport = new GeoParquetReadSupport( + convertTz, + enableVectorizedReader = false, + datetimeRebaseMode, + int96RebaseMode, + options) + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[InternalRow](readSupport, parquetFilter) + } else { + new ParquetRecordReader[InternalRow](readSupport) + } + val iter = new RecordReaderIterator[InternalRow](reader) + // SPARK-23457 Register a task completion listener before `initialization`. + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + reader.initialize(split, hadoopAttemptContext) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + if (partitionSchema.length == 0) { + // There is no partition columns + iter.map(unsafeProjection) + } else { + val joinedRow = new JoinedRow() + iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues))) + } + } + } + } + + override def supportDataType(dataType: DataType): Boolean = super.supportDataType(dataType) + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = false +} + +object GeoParquetFileFormat extends Logging { + + /** + * Figures out a merged Parquet schema with a distributed Spark job. + * + * Note that locality is not taken into consideration here because: + * + * 1. For a single Parquet part-file, in most cases the footer only resides in the last block + * of that file. Thus we only need to retrieve the location of the last block. However, + * Hadoop `FileSystem` only provides API to retrieve locations of all blocks, which can be + * potentially expensive. + * + * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty + * slow. And basically locality is not available when using S3 (you can't run computation on S3 + * nodes). 
+ */ + def mergeSchemasInParallel( + parameters: Map[String, String], + filesToTouch: Seq[FileStatus], + sparkSession: SparkSession): Option[StructType] = { + val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp + + val reader = (files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean) => { + readParquetFootersInParallel(conf, files, ignoreCorruptFiles) + .map { footer => + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val keyValueMetaData = footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData + val converter = new GeoParquetToSparkSchemaConverter( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp, + parameters = parameters) + readSchemaFromFooter(footer, keyValueMetaData, converter, parameters) + } + } + + GeoSchemaMergeUtils.mergeSchemasInParallel(sparkSession, parameters, filesToTouch, reader) + } + + private def readSchemaFromFooter( + footer: Footer, + keyValueMetaData: java.util.Map[String, String], + converter: GeoParquetToSparkSchemaConverter, + parameters: Map[String, String]): StructType = { + val fileMetaData = footer.getParquetMetadata.getFileMetaData + fileMetaData.getKeyValueMetaData.asScala.toMap + .get(ParquetReadSupport.SPARK_METADATA_KEY) + .flatMap(schema => deserializeSchemaString(schema, keyValueMetaData, parameters)) + .getOrElse(converter.convert(fileMetaData.getSchema)) + } + + private def deserializeSchemaString( + schemaString: String, + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): Option[StructType] = { + // Tries to deserialize the schema string as JSON first, then falls back to the case class + // string parser (data generated by older versions of Spark SQL uses this format). 
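+    // Illustrative note (not part of the original code): the schema string stored under the
+    // "org.apache.spark.sql.parquet.row.metadata" key is normally the JSON form of a StructType, e.g.
+    //   {"type":"struct","fields":[{"name":"geometry","type":"binary","nullable":true,"metadata":{}}]}
+    // which DataType.fromJson converts back to a StructType; columns listed in the GeoParquet "geo"
+    // metadata are then re-typed to GeometryUDT by replaceGeometryColumnWithGeometryUDT below.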
+ val schemaOpt = Try(DataType.fromJson(schemaString).asInstanceOf[StructType]) + .recover { case _: Throwable => + logInfo( + "Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + LegacyTypeStringParser.parseString(schemaString).asInstanceOf[StructType] + } + .recoverWith { case cause: Throwable => + logWarning( + "Failed to parse and ignored serialized Spark schema in " + + s"Parquet key-value metadata:\n\t$schemaString", + cause) + Failure(cause) + } + .toOption + + schemaOpt.map(schema => + replaceGeometryColumnWithGeometryUDT(schema, keyValueMetaData, parameters)) + } + + private def replaceGeometryColumnWithGeometryUDT( + schema: StructType, + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): StructType = { + val geoParquetMetaData: GeoParquetMetaData = + GeoParquetUtils.parseGeoParquetMetaData(keyValueMetaData, parameters) + val fields = schema.fields.map { field => + field.dataType match { + case _: BinaryType if geoParquetMetaData.columns.contains(field.name) => + field.copy(dataType = GeometryUDT) + case _ => field + } + } + StructType(fields) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala new file mode 100644 index 0000000000..d44f679058 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala @@ -0,0 +1,678 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ + +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +// Needed by Sedona to support Spark 3.0 - 3.3 +/** + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class GeoParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + case p: PrimitiveType => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getOriginalType, + p.getPrimitiveTypeName, + p.getTypeLength, + p.getDecimalMetadata))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. + case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + (field.fieldNames.toSeq.quoted, field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. 
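+      // Illustrative note (not part of the original code): with caseSensitive = false, a file containing
+      // both "ID" and "id" groups them under the same lower-cased key; the size-1 filter below drops such
+      // groups, so pushdown is silently skipped for ambiguous names instead of picking one arbitrarily.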
+ val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + length: Int, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) + + private def dateToDays(date: Any): Int = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + + private def timestampToMicros(v: Any): JLong = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = GeoDateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => 
FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + 
Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => 
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + 
FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. 
+  private def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
+    value == null || (nameToParquetField(name).fieldType match {
+      case ParquetBooleanType => value.isInstanceOf[JBoolean]
+      case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
+      case ParquetLongType => value.isInstanceOf[JLong]
+      case ParquetFloatType => value.isInstanceOf[JFloat]
+      case ParquetDoubleType => value.isInstanceOf[JDouble]
+      case ParquetStringType => value.isInstanceOf[String]
+      case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
+      case ParquetDateType =>
+        value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
+      case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
+        value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]
+      case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
+        isDecimalMatched(value, decimalMeta)
+      case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) =>
+        isDecimalMatched(value, decimalMeta)
+      case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) =>
+        isDecimalMatched(value, decimalMeta)
+      case _ => false
+    })
+  }
+
+  // For decimal types, the scale of the filter value must match the scale recorded in the file.
+  // If the scales do not match, pushing the filter down would produce corrupted results.
+  private def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match {
+    case decimal: JBigDecimal =>
+      decimal.scale == decimalMeta.getScale
+    case _ => false
+  }
+
+  private def canMakeFilterOn(name: String, value: Any): Boolean = {
+    nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
+  }
+
+  /**
+   * @param predicate
+   *   the input filter predicates. Not all of the predicates can be pushed down.
+   * @param canPartialPushDownConjuncts
+   *   whether a subset of conjuncts of predicates can be pushed down safely. Pushing down ONLY
+   *   one side of an AND is safe at the top level, or when none of its ancestors is a NOT or OR.
+   * @return
+   *   the Parquet-native filter predicates that are eligible for pushdown.
+   */
+  private def createFilterHelper(
+      predicate: sources.Filter,
+      canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = {
+    // NOTE:
+    //
+    // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
+    // which can be cast to `false` implicitly. Please refer to the `eval` method of these
+    // operators and the `PruneFilters` rule for details.
+
+    // Hyukjin:
+    // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]],
+    // so it performs the same comparison as when the given [[sources.Filter]] is [[EqualTo]].
+    // The reason is that the underlying Parquet filter performs a null-safe equality comparison,
+    // so perhaps [[EqualTo]] should be changed as well. It still seems fine, though, because
+    // physical planning does not pass `NULL` to [[EqualTo]] but rewrites it to [[IsNull]] etc.
+    // I may have missed something, in which case this should be revisited.
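As a standalone, illustrative sketch of what these conversions produce: for a pushed-down filter such as sources.And(sources.GreaterThan("price", 100), sources.EqualTo("city", "Berlin")), with a hypothetical INT32 column `price` and a UTF8-annotated BINARY column `city`, the makeGt and makeEq functions above would yield roughly the following Parquet predicate.

import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
import org.apache.parquet.io.api.Binary

object PushdownSketch {
  // gt on an INT32 column and eq on a BINARY (UTF8) column, combined with and(),
  // mirroring the predicates built by makeGt and makeEq for the data source filters above.
  val pushed: FilterPredicate = FilterApi.and(
    FilterApi.gt(FilterApi.intColumn("price"), Int.box(100)),
    FilterApi.eq(FilterApi.binaryColumn("city"), Binary.fromString("Berlin")))
}

The match on `predicate` below wires these builders to the corresponding sources.Filter cases.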
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if canMakeFilterOn(name, values.head) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct + .flatMap { v => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case _ => None + } + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala new file mode 100644 index 0000000000..a3c2be5d22 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.io.api.RecordMaterializer +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema._ +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ + +import java.time.ZoneId +import java.util.{Locale, Map => JMap} +import scala.collection.JavaConverters._ + +/** + * A Parquet [[ReadSupport]] implementation for reading Parquet records as Catalyst + * [[InternalRow]]s. + * + * The API interface of [[ReadSupport]] is a little bit over complicated because of historical + * reasons. In older versions of parquet-mr (say 1.6.0rc3 and prior), [[ReadSupport]] need to be + * instantiated and initialized twice on both driver side and executor side. The [[init()]] method + * is for driver side initialization, while [[prepareForRead()]] is for executor side. However, + * starting from parquet-mr 1.6.0, it's no longer the case, and [[ReadSupport]] is only + * instantiated and initialized on executor side. So, theoretically, now it's totally fine to + * combine these two methods into a single initialization method. The only reason (I could think + * of) to still have them here is for parquet-mr API backwards-compatibility. + * + * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from + * [[init()]] to [[prepareForRead()]], but use a private `var` for simplicity. + */ +class GeoParquetReadSupport( + override val convertTz: Option[ZoneId], + enableVectorizedReader: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String]) + extends ParquetReadSupport + with Logging { + private var catalystRequestedSchema: StructType = _ + + /** + * Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record + * readers. Responsible for figuring out Parquet requested schema used for column pruning. + */ + override def init(context: InitContext): ReadContext = { + val conf = context.getConfiguration + catalystRequestedSchema = { + val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + assert(schemaString != null, "Parquet requested schema not set.") + StructType.fromString(schemaString) + } + + val caseSensitive = + conf.getBoolean(SQLConf.CASE_SENSITIVE.key, SQLConf.CASE_SENSITIVE.defaultValue.get) + val schemaPruningEnabled = conf.getBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) + val parquetFileSchema = context.getFileSchema + val parquetClippedSchema = ParquetReadSupport.clipParquetSchema( + parquetFileSchema, + catalystRequestedSchema, + caseSensitive) + + // We pass two schema to ParquetRecordMaterializer: + // - parquetRequestedSchema: the schema of the file data we want to read + // - catalystRequestedSchema: the schema of the rows we want to return + // The reader is responsible for reconciling the differences between the two. 
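To make the clipping step above concrete, here is a small, self-contained sketch (column names are made up) of how a Catalyst requested schema prunes a Parquet file schema via the companion object's clipParquetSchema defined later in this file; the code that follows then decides whether to use the clipped schema directly or intersect it with the file schema.

import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.execution.datasources.parquet.GeoParquetReadSupport
import org.apache.spark.sql.types._

object ClipSketch {
  // The file contains three columns, but the query only selects `id` and `geometry`.
  val fileSchema = MessageTypeParser.parseMessageType(
    """message spark_schema {
      |  required int32 id;
      |  optional binary name (UTF8);
      |  optional binary geometry;
      |}""".stripMargin)

  // The requested Catalyst schema; the geometry column is stored as WKB bytes in the file.
  val requestedSchema = StructType(Seq(
    StructField("id", IntegerType, nullable = false),
    StructField("geometry", BinaryType)))

  // `name` is absent from the requested schema, so it is dropped from the read schema.
  val clipped = GeoParquetReadSupport.clipParquetSchema(fileSchema, requestedSchema)
}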
+ val parquetRequestedSchema = if (schemaPruningEnabled && !enableVectorizedReader) { + // Parquet-MR reader requires that parquetRequestedSchema include only those fields present + // in the underlying parquetFileSchema. Therefore, we intersect the parquetClippedSchema + // with the parquetFileSchema + GeoParquetReadSupport + .intersectParquetGroups(parquetClippedSchema, parquetFileSchema) + .map(groupType => new MessageType(groupType.getName, groupType.getFields)) + .getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE) + } else { + // Spark's vectorized reader only support atomic types currently. It also skip fields + // in parquetRequestedSchema which are not present in the file. + parquetClippedSchema + } + logDebug( + s"""Going to read the following fields from the Parquet file with the following schema: + |Parquet file schema: + |$parquetFileSchema + |Parquet clipped schema: + |$parquetClippedSchema + |Parquet requested schema: + |$parquetRequestedSchema + |Catalyst requested schema: + |${catalystRequestedSchema.treeString} + """.stripMargin) + new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. + */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + val parquetRequestedSchema = readContext.getRequestedSchema + new GeoParquetRecordMaterializer( + parquetRequestedSchema, + GeoParquetReadSupport.expandUDT(catalystRequestedSchema), + new GeoParquetToSparkSchemaConverter(keyValueMetaData, conf, parameters), + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters) + } +} + +object GeoParquetReadSupport extends Logging { + + /** + * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist in + * `catalystSchema`, and adding those only exist in `catalystSchema`. + */ + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + val clippedParquetFields = + clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema, caseSensitive) + if (clippedParquetFields.isEmpty) { + ParquetSchemaConverter.EMPTY_MESSAGE + } else { + Types + .buildMessage() + .addFields(clippedParquetFields: _*) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + } + } + + private def clipParquetType( + parquetType: Type, + catalystType: DataType, + caseSensitive: Boolean): Type = { + catalystType match { + case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => + // Only clips array types with nested type as element type. + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + + case t: MapType + if !isPrimitiveCatalystType(t.keyType) || + !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested key type or value type + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + + case t: StructType => + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + + case _ => + // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able + // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. 
+ parquetType + } + } + + /** + * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to + * [[AtomicType]]. For example, [[CalendarIntervalType]] is primitive, but it's not an + * [[AtomicType]]. + */ + private def isPrimitiveCatalystType(dataType: DataType): Boolean = { + dataType match { + case _: ArrayType | _: MapType | _: StructType => false + case _ => true + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type + * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or + * a [[StructType]]. + */ + private def clipParquetListType( + parquetList: GroupType, + elementType: DataType, + caseSensitive: Boolean): Type = { + // Precondition of this method, should only be called for lists with nested element types. + assert(!isPrimitiveCatalystType(elementType)) + + // Unannotated repeated group should be interpreted as required list of required element, so + // list element type is just the group itself. Clip it. + if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { + clipParquetType(parquetList, elementType, caseSensitive) + } else { + assert( + parquetList.getOriginalType == OriginalType.LIST, + "Invalid Parquet schema. " + + "Original type of annotated Parquet lists must be LIST: " + + parquetList.toString) + + assert( + parquetList.getFieldCount == 1 && parquetList + .getType(0) + .isRepetition(Repetition.REPEATED), + "Invalid Parquet schema. " + + "LIST-annotated group should only have exactly one repeated field: " + + parquetList) + + // Precondition of this method, should only be called for lists with nested element types. + assert(!parquetList.getType(0).isPrimitive) + + val repeatedGroup = parquetList.getType(0).asGroupType() + + // If the repeated field is a group with multiple fields, or the repeated field is a group + // with one field and is named either "array" or uses the LIST-annotated group's name with + // "_tuple" appended then the repeated type is the element type and elements are required. + // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the + // only field. + if (repeatedGroup.getFieldCount > 1 || + repeatedGroup.getName == "array" || + repeatedGroup.getName == parquetList.getName + "_tuple") { + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .named(parquetList.getName) + } else { + // Otherwise, the repeated field's type is the element type with the repeated field's + // repetition. + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField( + Types + .repeatedGroup() + .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) + .named(repeatedGroup.getName)) + .named(parquetList.getName) + } + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. Either key type or + * value type of the [[MapType]] must be a nested type, namely an [[ArrayType]], a [[MapType]], + * or a [[StructType]]. + */ + private def clipParquetMapType( + parquetMap: GroupType, + keyType: DataType, + valueType: DataType, + caseSensitive: Boolean): GroupType = { + // Precondition of this method, only handles maps with nested key types or value types. 
+ assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) + + val repeatedGroup = parquetMap.getType(0).asGroupType() + val parquetKeyType = repeatedGroup.getType(0) + val parquetValueType = repeatedGroup.getType(1) + + val clippedRepeatedGroup = + Types + .repeatedGroup() + .as(repeatedGroup.getOriginalType) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .named(repeatedGroup.getName) + + Types + .buildGroup(parquetMap.getRepetition) + .as(parquetMap.getOriginalType) + .addField(clippedRepeatedGroup) + .named(parquetMap.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return + * A clipped [[GroupType]], which has at least one field. + * @note + * Parquet doesn't allow creating empty [[GroupType]] instances except for empty + * [[MessageType]]. Because it's legal to construct an empty requested schema for column + * pruning. + */ + private def clipParquetGroup( + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) + Types + .buildGroup(parquetRecord.getRepetition) + .as(parquetRecord.getOriginalType) + .addFields(clippedParquetFields: _*) + .named(parquetRecord.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return + * A list of clipped [[GroupType]] fields, which can be empty. + */ + private def clipParquetGroupFields( + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean): Seq[Type] = { + val toParquet = new SparkToGeoParquetSchemaConverter(writeLegacyParquetFormat = false) + if (caseSensitive) { + val caseSensitiveParquetFieldMap = + parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + structType.map { f => + caseSensitiveParquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType, caseSensitive)) + .getOrElse(toParquet.convertField(f)) + } + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + structType.map { f => + caseInsensitiveParquetFieldMap + .get(f.name.toLowerCase(Locale.ROOT)) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException( + s"""Found duplicate field(s) "${f.name}": """ + + s"$parquetTypesString in case-insensitive mode") + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + } + } + .getOrElse(toParquet.convertField(f)) + } + } + } + + /** + * Computes the structural intersection between two Parquet group types. This is used to create + * a requestedSchema for ReadContext of Parquet-MR reader. Parquet-MR reader does not support + * the nested field access to non-existent field while parquet library does support to read the + * non-existent field by regular field access. 
+ */ + private def intersectParquetGroups( + groupType1: GroupType, + groupType2: GroupType): Option[GroupType] = { + val fields = + groupType1.getFields.asScala + .filter(field => groupType2.containsField(field.getName)) + .flatMap { + case field1: GroupType => + val field2 = groupType2.getType(field1.getName) + if (field2.isPrimitive) { + None + } else { + intersectParquetGroups(field1, field2.asGroupType) + } + case field1 => Some(field1) + } + + if (fields.nonEmpty) { + Some(groupType1.withNewFields(fields.asJava)) + } else { + None + } + } + + def expandUDT(schema: StructType): StructType = { + def expand(dataType: DataType): DataType = { + dataType match { + case t: ArrayType => + t.copy(elementType = expand(t.elementType)) + + case t: MapType => + t.copy(keyType = expand(t.keyType), valueType = expand(t.valueType)) + + case t: StructType => + val expandedFields = t.fields.map(f => f.copy(dataType = expand(f.dataType))) + t.copy(fields = expandedFields) + + // Don't expand GeometryUDT types. We'll treat geometry columns specially in + // GeoParquetRowConverter + case t: GeometryUDT => t + + case t: UserDefinedType[_] => + t.sqlType + + case t => + t + } + } + + expand(schema).asInstanceOf[StructType] + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala new file mode 100644 index 0000000000..dedbb237b5 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.time.ZoneId +import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} +import org.apache.parquet.schema.MessageType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.types.StructType + +/** + * A [[RecordMaterializer]] for Catalyst rows. 
+ *
+ * @param parquetSchema
+ *   Parquet schema of the records to be read
+ * @param catalystSchema
+ *   Catalyst schema of the rows to be constructed
+ * @param schemaConverter
+ *   A Parquet-Catalyst schema converter that helps initializing row converters
+ * @param convertTz
+ *   the optional time zone to convert to int96 data
+ * @param datetimeRebaseMode
+ *   the mode of rebasing date/timestamp from Julian to Proleptic Gregorian calendar
+ * @param int96RebaseMode
+ *   the mode of rebasing INT96 timestamps from Julian to Proleptic Gregorian calendar
+ * @param parameters
+ *   Options for reading GeoParquet files. For example, if legacyMode is enabled or not.
+ */
+class GeoParquetRecordMaterializer(
+    parquetSchema: MessageType,
+    catalystSchema: StructType,
+    schemaConverter: GeoParquetToSparkSchemaConverter,
+    convertTz: Option[ZoneId],
+    datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+    int96RebaseMode: LegacyBehaviorPolicy.Value,
+    parameters: Map[String, String])
+    extends RecordMaterializer[InternalRow] {
+  private val rootConverter = new GeoParquetRowConverter(
+    schemaConverter,
+    parquetSchema,
+    catalystSchema,
+    convertTz,
+    datetimeRebaseMode,
+    int96RebaseMode,
+    parameters,
+    NoopUpdater)
+
+  override def getCurrentRecord: InternalRow = rootConverter.currentRecord
+
+  override def getRootConverter: GroupConverter = rootConverter
+}
diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
new file mode 100644
index 0000000000..2f2eea38cd
--- /dev/null
+++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
@@ -0,0 +1,745 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.OriginalType.LIST +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.{GroupType, OriginalType, Type} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CaseInsensitiveMap, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.locationtech.jts.io.WKBReader + +import java.math.{BigDecimal, BigInteger} +import java.time.{ZoneId, ZoneOffset} +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +/** + * A [[ParquetRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s. + * Since Catalyst `StructType` is also a Parquet record, this converter can be used as root + * converter. Take the following Parquet type as an example: + * {{{ + * message root { + * required int32 f1; + * optional group f2 { + * required double f21; + * optional binary f22 (utf8); + * } + * } + * }}} + * 5 converters will be created: + * + * - a root [[ParquetRowConverter]] for [[org.apache.parquet.schema.MessageType]] `root`, which + * contains: + * - a [[ParquetPrimitiveConverter]] for required + * [[org.apache.parquet.schema.OriginalType.INT_32]] field `f1`, and + * - a nested [[ParquetRowConverter]] for optional [[GroupType]] `f2`, which contains: + * - a [[ParquetPrimitiveConverter]] for required + * [[org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE]] field `f21`, and + * - a [[ParquetStringConverter]] for optional + * [[org.apache.parquet.schema.OriginalType.UTF8]] string field `f22` + * + * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have + * any "parent" container. + * + * @param schemaConverter + * A utility converter used to convert Parquet types to Catalyst types. + * @param parquetType + * Parquet schema of Parquet records + * @param catalystType + * Spark SQL schema that corresponds to the Parquet record type. User-defined types other than + * [[GeometryUDT]] should have been expanded. + * @param convertTz + * the optional time zone to convert to int96 data + * @param datetimeRebaseMode + * the mode of rebasing date/timestamp from Julian to Proleptic Gregorian calendar + * @param int96RebaseMode + * the mode of rebasing INT96 timestamp from Julian to Proleptic Gregorian calendar + * @param parameters + * Options for reading GeoParquet files. For example, if legacyMode is enabled or not. 
+ * @param updater + * An updater which propagates converted field values to the parent container + */ +private[parquet] class GeoParquetRowConverter( + schemaConverter: GeoParquetToSparkSchemaConverter, + parquetType: GroupType, + catalystType: StructType, + convertTz: Option[ZoneId], + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String], + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) + with Logging { + + assert( + parquetType.getFieldCount <= catalystType.length, + s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema: + | + |Parquet schema: + |$parquetType + |Catalyst schema: + |${catalystType.prettyJson} + """.stripMargin) + + assert( + !catalystType.existsRecursively(t => + !t.isInstanceOf[GeometryUDT] && t.isInstanceOf[UserDefinedType[_]]), + s"""User-defined types in Catalyst schema should have already been expanded: + |${catalystType.prettyJson} + """.stripMargin) + + logDebug(s"""Building row converter for the following schema: + | + |Parquet form: + |$parquetType + |Catalyst form: + |${catalystType.prettyJson} + """.stripMargin) + + /** + * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates + * converted filed values to the `ordinal`-th cell in `currentRow`. + */ + private final class RowUpdater(row: InternalRow, ordinal: Int) extends ParentContainerUpdater { + override def set(value: Any): Unit = row(ordinal) = value + override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(value: Short): Unit = row.setShort(ordinal, value) + override def setInt(value: Int): Unit = row.setInt(ordinal, value) + override def setLong(value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) + } + + private[this] val currentRow = new SpecificInternalRow(catalystType.map(_.dataType)) + + /** + * The [[InternalRow]] converted from an entire Parquet record. + */ + def currentRecord: InternalRow = currentRow + + private val dateRebaseFunc = + GeoDataSourceUtils.creteDateRebaseFuncInRead(datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInRead(datetimeRebaseMode, "Parquet") + + private val int96RebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInRead(int96RebaseMode, "Parquet INT96") + + // Converters for each field. + private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { + // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false + // to prevent throwing IllegalArgumentException when searching catalyst type's field index + val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) { + catalystType.fieldNames.zipWithIndex.toMap + } else { + CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap) + } + parquetType.getFields.asScala.map { parquetField => + val fieldIndex = catalystFieldNameToIndex(parquetField.getName) + val catalystField = catalystType(fieldIndex) + // Converted field value should be set to the `fieldIndex`-th cell of `currentRow` + newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex)) + }.toArray + } + + // Updaters for each field. 
+ private[this] val fieldUpdaters: Array[ParentContainerUpdater] = fieldConverters.map(_.updater) + + override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) + + override def end(): Unit = { + var i = 0 + while (i < fieldUpdaters.length) { + fieldUpdaters(i).end() + i += 1 + } + updater.set(currentRow) + } + + override def start(): Unit = { + var i = 0 + val numFields = currentRow.numFields + while (i < numFields) { + currentRow.setNullAt(i) + i += 1 + } + i = 0 + while (i < fieldUpdaters.length) { + fieldUpdaters(i).start() + i += 1 + } + } + + /** + * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type + * `catalystType`. Converted values are handled by `updater`. + */ + private def newConverter( + parquetType: Type, + catalystType: DataType, + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { + + catalystType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => + new ParquetPrimitiveConverter(updater) + + case GeometryUDT => + if (parquetType.isPrimitive) { + new ParquetPrimitiveConverter(updater) { + override def addBinary(value: Binary): Unit = { + val wkbReader = new WKBReader() + val geom = wkbReader.read(value.getBytes) + updater.set(GeometryUDT.serialize(geom)) + } + } + } else { + if (GeoParquetUtils.isLegacyMode(parameters)) { + new ParquetArrayConverter( + parquetType.asGroupType(), + ArrayType(ByteType, containsNull = false), + updater) { + override def end(): Unit = { + val wkbReader = new WKBReader() + val byteArray = currentArray.map(_.asInstanceOf[Byte]).toArray + val geom = wkbReader.read(byteArray) + updater.set(GeometryUDT.serialize(geom)) + } + } + } else { + throw new IllegalArgumentException( + s"Parquet type for geometry column is $parquetType. This parquet file could be written by " + + "Apache Sedona <= 1.3.1-incubating. Please use option(\"legacyMode\", \"true\") to read this file.") + } + } + + case ByteType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + updater.setByte(value.asInstanceOf[ByteType#InternalType]) + + override def addBinary(value: Binary): Unit = { + val bytes = value.getBytes + for (b <- bytes) { + updater.set(b) + } + } + } + + case ShortType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + updater.setShort(value.asInstanceOf[ShortType#InternalType]) + } + + // For INT32 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => + new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For INT64 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 => + new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals + case t: DecimalType + if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY || + parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY => + new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + case t: DecimalType => + throw new RuntimeException( + s"Unable to create Parquet converter for decimal type ${t.json} whose Parquet type is " + + s"$parquetType. 
Parquet DECIMAL type can only be backed by INT32, INT64, " + + "FIXED_LEN_BYTE_ARRAY, or BINARY.") + + case StringType => + new ParquetStringConverter(updater) + + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(timestampRebaseFunc(value)) + } + } + + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + val micros = GeoDateTimeUtils.millisToMicros(value) + updater.setLong(timestampRebaseFunc(micros)) + } + } + + // INT96 timestamp doesn't have a logical type, here we check the physical type instead. + case TimestampType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT96 => + new ParquetPrimitiveConverter(updater) { + // Converts nanosecond timestamps stored as INT96 + override def addBinary(value: Binary): Unit = { + val julianMicros = ParquetRowConverter.binaryToSQLTimestamp(value) + val gregorianMicros = int96RebaseFunc(julianMicros) + val adjTime = convertTz + .map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) + .getOrElse(gregorianMicros) + updater.setLong(adjTime) + } + } + + case DateType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = { + updater.set(dateRebaseFunc(value)) + } + } + + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + + case t: ArrayType => + new ParquetArrayConverter(parquetType.asGroupType(), t, updater) + + case t: MapType => + new ParquetMapConverter(parquetType.asGroupType(), t, updater) + + case t: StructType => + val wrappedUpdater = { + // SPARK-30338: avoid unnecessary InternalRow copying for nested structs: + // There are two cases to handle here: + // + // 1. Parent container is a map or array: we must make a deep copy of the mutable row + // because this converter may be invoked multiple times per Parquet input record + // (if the map or array contains multiple elements). + // + // 2. Parent container is a struct: we don't need to copy the row here because either: + // + // (a) all ancestors are structs and therefore no copying is required because this + // converter will only be invoked once per Parquet input record, or + // (b) some ancestor is struct that is nested in a map or array and that ancestor's + // converter will perform deep-copying (which will recursively copy this row). + if (updater.isInstanceOf[RowUpdater]) { + // `updater` is a RowUpdater, implying that the parent container is a struct. + updater + } else { + // `updater` is NOT a RowUpdater, implying that the parent container a map or array. 
+ new ParentContainerUpdater { + override def set(value: Any): Unit = { + updater.set(value.asInstanceOf[SpecificInternalRow].copy()) // deep copy + } + } + } + } + new GeoParquetRowConverter( + schemaConverter, + parquetType.asGroupType(), + t, + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters, + wrappedUpdater) + + case t => + throw new RuntimeException( + s"Unable to create Parquet converter for data type ${t.json} " + + s"whose Parquet type is $parquetType") + } + } + + /** + * Parquet converter for strings. A dictionary is used to minimize string decoding cost. + */ + private final class ParquetStringConverter(updater: ParentContainerUpdater) + extends ParquetPrimitiveConverter(updater) { + + private var expandedDictionary: Array[UTF8String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { i => + UTF8String.fromBytes(dictionary.decodeToBinary(i).getBytes) + } + } + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + override def addBinary(value: Binary): Unit = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here we + // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying + // it. + val buffer = value.toByteBuffer + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() + updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) + } + } + + /** + * Parquet converter for fixed-precision decimals. + */ + private abstract class ParquetDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetPrimitiveConverter(updater) { + + protected var expandedDictionary: Array[Decimal] = _ + + override def hasDictionarySupport: Boolean = true + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + // Converts decimals stored as INT32 + override def addInt(value: Int): Unit = { + addLong(value: Long) + } + + // Converts decimals stored as INT64 + override def addLong(value: Long): Unit = { + updater.set(decimalFromLong(value)) + } + + // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY + override def addBinary(value: Binary): Unit = { + updater.set(decimalFromBinary(value)) + } + + protected def decimalFromLong(value: Long): Decimal = { + Decimal(value, precision, scale) + } + + protected def decimalFromBinary(value: Binary): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + val unscaled = ParquetRowConverter.binaryToUnscaledLong(value) + Decimal(unscaled, precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
+ Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) + } + } + } + + private class ParquetIntDictionaryAwareDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToInt(id).toLong) + } + } + } + + private class ParquetLongDictionaryAwareDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToLong(id)) + } + } + } + + private class ParquetBinaryDictionaryAwareDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromBinary(dictionary.decodeToBinary(id)) + } + } + } + + /** + * Parquet converter for arrays. Spark SQL arrays are represented as Parquet lists. Standard + * Parquet lists are represented as a 3-level group annotated by `LIST`: + * {{{ + * group (LIST) { <-- parquetSchema points here + * repeated group list { + * element; + * } + * } + * }}} + * The `parquetSchema` constructor argument points to the outermost group. + * + * However, before this representation is standardized, some Parquet libraries/tools also use + * some non-standard formats to represent list-like structures. Backwards-compatibility rules + * for handling these cases are described in Parquet format spec. + * + * @see + * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + */ + private class ParquetArrayConverter( + parquetSchema: GroupType, + catalystSchema: ArrayType, + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) { + + protected[this] val currentArray: mutable.ArrayBuffer[Any] = ArrayBuffer.empty[Any] + + private[this] val elementConverter: Converter = { + val repeatedType = parquetSchema.getType(0) + val elementType = catalystSchema.elementType + + // At this stage, we're not sure whether the repeated field maps to the element type or is + // just the syntactic repeated group of the 3-level standard LIST layout. Take the following + // Parquet LIST-annotated group type as an example: + // + // optional group f (LIST) { + // repeated group list { + // optional group element { + // optional int32 element; + // } + // } + // } + // + // This type is ambiguous: + // + // 1. When interpreted as a standard 3-level layout, the `list` field is just the syntactic + // group, and the entire type should be translated to: + // + // ARRAY> + // + // 2. On the other hand, when interpreted as a non-standard 2-level layout, the `list` field + // represents the element type, and the entire type should be translated to: + // + // ARRAY>> + // + // Here we try to convert field `list` into a Catalyst type to see whether the converted type + // matches the Catalyst array element type. If it doesn't match, then it's case 1; otherwise, + // it's case 2. 
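To illustrate the ambiguity described in the comment above, the following hand-written sketch spells out the two candidate Catalyst interpretations; the `guessedElementType` comparison right below is what resolves them.

import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.types._

object ListLayoutSketch {
  // The ambiguous LIST-annotated group from the comment above.
  val ambiguous = MessageTypeParser.parseMessageType(
    """message m {
      |  optional group f (LIST) {
      |    repeated group list {
      |      optional group element {
      |        optional int32 element;
      |      }
      |    }
      |  }
      |}""".stripMargin)

  // Case 1: standard 3-level layout, where `list` is purely syntactic.
  val threeLevel = ArrayType(StructType(Seq(StructField("element", IntegerType))))

  // Case 2: legacy 2-level layout, where `list` itself is the element type.
  val twoLevel = ArrayType(
    StructType(Seq(StructField("element", StructType(Seq(StructField("element", IntegerType)))))))
}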
+ val guessedElementType = schemaConverter.convertFieldWithGeo(repeatedType) + + if (DataType.equalsIgnoreCompatibleNullability(guessedElementType, elementType)) { + // If the repeated field corresponds to the element type, creates a new converter using the + // type of the repeated field. + newConverter( + repeatedType, + elementType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentArray += value + }) + } else { + // If the repeated field corresponds to the syntactic group in the standard 3-level Parquet + // LIST layout, creates a new converter using the only child field of the repeated field. + assert(!repeatedType.isPrimitive && repeatedType.asGroupType().getFieldCount == 1) + new ElementConverter(repeatedType.asGroupType().getType(0), elementType) + } + } + + override def getConverter(fieldIndex: Int): Converter = elementConverter + + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + + override def start(): Unit = currentArray.clear() + + /** Array element converter */ + private final class ElementConverter(parquetType: Type, catalystType: DataType) + extends GroupConverter { + + private var currentElement: Any = _ + + private[this] val converter = + newConverter( + parquetType, + catalystType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentElement = value + }) + + override def getConverter(fieldIndex: Int): Converter = converter + + override def end(): Unit = currentArray += currentElement + + override def start(): Unit = currentElement = null + } + } + + /** Parquet converter for maps */ + private final class ParquetMapConverter( + parquetType: GroupType, + catalystType: MapType, + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) { + + private[this] val currentKeys = ArrayBuffer.empty[Any] + private[this] val currentValues = ArrayBuffer.empty[Any] + + private[this] val keyValueConverter = { + val repeatedType = parquetType.getType(0).asGroupType() + new KeyValueConverter( + repeatedType.getType(0), + repeatedType.getType(1), + catalystType.keyType, + catalystType.valueType) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override def end(): Unit = { + // The parquet map may contains null or duplicated map keys. When it happens, the behavior is + // undefined. + // TODO (SPARK-26174): disallow it with a config. + updater.set( + new ArrayBasedMapData( + new GenericArrayData(currentKeys.toArray), + new GenericArrayData(currentValues.toArray))) + } + + override def start(): Unit = { + currentKeys.clear() + currentValues.clear() + } + + /** Parquet converter for key-value pairs within the map. 
*/ + private final class KeyValueConverter( + parquetKeyType: Type, + parquetValueType: Type, + catalystKeyType: DataType, + catalystValueType: DataType) + extends GroupConverter { + + private var currentKey: Any = _ + + private var currentValue: Any = _ + + private[this] val converters = Array( + // Converter for keys + newConverter( + parquetKeyType, + catalystKeyType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentKey = value + }), + + // Converter for values + newConverter( + parquetValueType, + catalystValueType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentValue = value + })) + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + override def end(): Unit = { + currentKeys += currentKey + currentValues += currentValue + } + + override def start(): Unit = { + currentKey = null + currentValue = null + } + } + } + + private trait RepeatedConverter { + private[this] val currentArray = ArrayBuffer.empty[Any] + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray.clear() + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. + */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter + with RepeatedConverter + with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private[this] val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = + elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. 
+ */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter + with HasParentContainerUpdater + with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private[this] val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala new file mode 100644 index 0000000000..eab20875a6 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema._ +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.checkConversionRequirement +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ + +/** + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]]. + * + * Parquet format backwards-compatibility rules are respected when converting Parquet + * [[MessageType]] schemas. + * + * @see + * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * @param assumeBinaryIsString + * Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields. + * @param assumeInt96IsTimestamp + * Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields. + * @param parameters + * Options for reading GeoParquet files. 
+ */ +class GeoParquetToSparkSchemaConverter( + keyValueMetaData: java.util.Map[String, String], + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + parameters: Map[String, String]) { + + private val geoParquetMetaData: GeoParquetMetaData = + GeoParquetUtils.parseGeoParquetMetaData(keyValueMetaData, parameters) + + def this( + keyValueMetaData: java.util.Map[String, String], + conf: SQLConf, + parameters: Map[String, String]) = this( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = conf.isParquetBinaryAsString, + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + parameters = parameters) + + def this( + keyValueMetaData: java.util.Map[String, String], + conf: Configuration, + parameters: Map[String, String]) = this( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + parameters = parameters) + + /** + * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. + */ + def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + + private def convert(parquetSchema: GroupType): StructType = { + val fields = parquetSchema.getFields.asScala.map { field => + field.getRepetition match { + case OPTIONAL => + StructField(field.getName, convertFieldWithGeo(field), nullable = true) + + case REQUIRED => + StructField(field.getName, convertFieldWithGeo(field), nullable = false) + + case REPEATED => + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + val arrayType = ArrayType(convertFieldWithGeo(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) + } + } + + StructType(fields.toSeq) + } + + /** + * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. + */ + def convertFieldWithGeo(parquetType: Type): DataType = parquetType match { + case t: PrimitiveType => convertPrimitiveField(t) + case t: GroupType => convertGroupField(t.asGroupType()) + } + + private def isGeometryField(fieldName: String): Boolean = + geoParquetMetaData.columns.contains(fieldName) + + private def convertPrimitiveField(field: PrimitiveType): DataType = { + val typeName = field.getPrimitiveTypeName + val originalType = field.getOriginalType + + def typeString = + if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + + def typeNotSupported() = + throw new IllegalArgumentException(s"Parquet type not supported: $typeString") + + def typeNotImplemented() = + throw new IllegalArgumentException(s"Parquet type not yet supported: $typeString") + + def illegalType() = + throw new IllegalArgumentException(s"Illegal Parquet type: $typeString") + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. 
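+    // For example (illustrative): an INT32 (DECIMAL(9, 2)) column is checked against
+    // Decimal.MAX_INT_DIGITS and maps to DecimalType(9, 2), whereas a BINARY (DECIMAL(20, 4))
+    // column goes through makeDecimalType() with no precision range check.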
+ def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val precision = field.getDecimalMetadata.getPrecision + val scale = field.getDecimalMetadata.getScale + + ParquetSchemaConverter.checkConversionRequirement( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + typeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + originalType match { + case INT_8 => ByteType + case INT_16 => ShortType + case INT_32 | null => IntegerType + case DATE => DateType + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) + case UINT_8 => typeNotSupported() + case UINT_16 => typeNotSupported() + case UINT_32 => typeNotSupported() + case TIME_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + originalType match { + case INT_64 | null => LongType + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) + case UINT_64 => typeNotSupported() + case TIMESTAMP_MICROS => TimestampType + case TIMESTAMP_MILLIS => TimestampType + case _ => illegalType() + } + + case INT96 => + ParquetSchemaConverter.checkConversionRequirement( + assumeInt96IsTimestamp, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + originalType match { + case UTF8 | ENUM | JSON => StringType + case null if isGeometryField(field.getName) => GeometryUDT + case null if assumeBinaryIsString => StringType + case null => BinaryType + case BSON => BinaryType + case DECIMAL => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + originalType match { + case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) + case INTERVAL => typeNotImplemented() + case _ => illegalType() + } + + case _ => illegalType() + } + } + + private def convertGroupField(field: GroupType): DataType = { + Option(field.getOriginalType).fold(convert(field): DataType) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. 
+ // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case LIST => + ParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1, + s"Invalid list type $field") + + val repeatedType = field.getType(0) + ParquetSchemaConverter.checkConversionRequirement( + repeatedType.isRepetition(REPEATED), + s"Invalid list type $field") + + if (isElementTypeWithGeo(repeatedType, field.getName)) { + ArrayType(convertFieldWithGeo(repeatedType), containsNull = false) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val optional = elementType.isRepetition(OPTIONAL) + ArrayType(convertFieldWithGeo(elementType), containsNull = optional) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case MAP | MAP_KEY_VALUE => + ParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + + val keyValueType = field.getType(0).asGroupType() + ParquetSchemaConverter.checkConversionRequirement( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val keyType = keyValueType.getType(0) + val valueType = keyValueType.getType(1) + val valueOptional = valueType.isRepetition(OPTIONAL) + MapType( + convertFieldWithGeo(keyType), + convertFieldWithGeo(valueType), + valueContainsNull = valueOptional) + + case _ => + throw new IllegalArgumentException(s"Unrecognized Parquet type: $field") + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + def isElementTypeWithGeo(repeatedType: Type, parentName: String): Boolean = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // ARRAY (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } +} + +/** + * This converter class is used to convert Spark SQL [[StructType]] to Parquet [[MessageType]]. + * + * @param writeLegacyParquetFormat + * Whether to use legacy Parquet format compatible with Spark 1.4 and prior versions when + * converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. When set to false, use + * standard format defined in parquet-format spec. 
This argument only affects Parquet write + * path. + * @param outputTimestampType + * which parquet timestamp type to use when writing. + */ +class SparkToGeoParquetSchemaConverter( + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96) + extends SparkToParquetSchemaConverter(writeLegacyParquetFormat, outputTimestampType) { + + def this(conf: SQLConf) = this( + writeLegacyParquetFormat = conf.writeLegacyParquetFormat, + outputTimestampType = conf.parquetOutputTimestampType) + + def this(conf: Configuration) = this( + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, + outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) + + /** + * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. + */ + override def convert(catalystSchema: StructType): MessageType = { + Types + .buildMessage() + .addFields(catalystSchema.map(convertField): _*) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + } + + /** + * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. + */ + override def convertField(field: StructField): Type = { + convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + } + + private def convertField(field: StructField, repetition: Type.Repetition): Type = { + GeoParquetSchemaConverter.checkFieldName(field.name) + + field.dataType match { + // =================== + // Simple atomic types + // =================== + + case BooleanType => + Types.primitive(BOOLEAN, repetition).named(field.name) + + case ByteType => + Types.primitive(INT32, repetition).as(INT_8).named(field.name) + + case ShortType => + Types.primitive(INT32, repetition).as(INT_16).named(field.name) + + case IntegerType => + Types.primitive(INT32, repetition).named(field.name) + + case LongType => + Types.primitive(INT64, repetition).named(field.name) + + case FloatType => + Types.primitive(FLOAT, repetition).named(field.name) + + case DoubleType => + Types.primitive(DOUBLE, repetition).named(field.name) + + case StringType => + Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + + case DateType => + Types.primitive(INT32, repetition).as(DATE).named(field.name) + + // NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or + // TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the + // behavior same as before. + // + // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond + // timestamp in Impala for some historical reasons. It's not recommended to be used for any + // other types and will probably be deprecated in some future version of parquet-format spec. + // That's the reason why parquet-format spec only defines `TIMESTAMP_MILLIS` and + // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`. + // + // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting + // from Spark 1.5.0, we resort to a timestamp type with microsecond precision so that we can + // store a timestamp into a `Long`. This design decision is subject to change though, for + // example, we may resort to nanosecond precision in the future. 
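+      // To summarize the cases below: INT96 keeps the physical INT96 representation, while
+      // TIMESTAMP_MICROS and TIMESTAMP_MILLIS are both written as INT64 with the corresponding
+      // logical type annotation.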
+ case TimestampType => + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + Types.primitive(INT96, repetition).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MICROS).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + } + + case BinaryType => + Types.primitive(BINARY, repetition).named(field.name) + + // ====================== + // Decimals (legacy mode) + // ====================== + + // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and + // always store decimals in fixed-length byte arrays. To keep compatibility with these older + // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated + // by `DECIMAL`. + case DecimalType.Fixed(precision, scale) if writeLegacyParquetFormat => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(Decimal.minBytesForPrecision(precision)) + .named(field.name) + + // ======================== + // Decimals (standard mode) + // ======================== + + // Uses INT32 for 1 <= precision <= 9 + case DecimalType.Fixed(precision, scale) + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => + Types + .primitive(INT32, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses INT64 for 1 <= precision <= 18 + case DecimalType.Fixed(precision, scale) + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => + Types + .primitive(INT64, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses FIXED_LEN_BYTE_ARRAY for all other precisions + case DecimalType.Fixed(precision, scale) if !writeLegacyParquetFormat => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(Decimal.minBytesForPrecision(precision)) + .named(field.name) + + // =================================== + // ArrayType and MapType (legacy mode) + // =================================== + + // Spark 1.4.x and prior versions convert `ArrayType` with nullable elements into a 3-level + // `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro + // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element + // field name "array" is borrowed from parquet-avro. + case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat => + // group (LIST) { + // optional group bag { + // repeated array; + // } + // } + + // This should not use `listOfElements` here because this new method checks if the + // element name is `element` in the `GroupType` and throws an exception if not. + // As mentioned above, Spark prior to 1.4.x writes `ArrayType` as `LIST` but with + // `array` as its element name as below. Therefore, we build manually + // the correct group type here via the builder. (See SPARK-16777) + Types + .buildGroup(repetition) + .as(LIST) + .addField( + Types + .buildGroup(REPEATED) + // "array" is the name chosen by parquet-hive (1.7.0 and prior version) + .addField(convertField(StructField("array", elementType, nullable))) + .named("bag")) + .named(field.name) + + // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level + // LIST structure. 
This behavior mimics parquet-avro (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat => + // group (LIST) { + // repeated element; + // } + + // Here too, we should not use `listOfElements`. (See SPARK-16777) + Types + .buildGroup(repetition) + .as(LIST) + // "array" is the name chosen by parquet-avro (1.7.0 and prior version) + .addField(convertField(StructField("array", elementType, nullable), REPEATED)) + .named(field.name) + + // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by + // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. + case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat => + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // required key; + // value; + // } + // } + ConversionPatterns.mapType( + repetition, + field.name, + convertField(StructField("key", keyType, nullable = false)), + convertField(StructField("value", valueType, valueContainsNull))) + + // ===================================== + // ArrayType and MapType (standard mode) + // ===================================== + + case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat => + // group (LIST) { + // repeated group list { + // element; + // } + // } + Types + .buildGroup(repetition) + .as(LIST) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("element", elementType, containsNull))) + .named("list")) + .named(field.name) + + case MapType(keyType, valueType, valueContainsNull) => + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + Types + .buildGroup(repetition) + .as(MAP) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("key", keyType, nullable = false))) + .addField(convertField(StructField("value", valueType, valueContainsNull))) + .named("key_value")) + .named(field.name) + + // =========== + // Other types + // =========== + + case StructType(fields) => + fields + .foldLeft(Types.buildGroup(repetition)) { (builder, field) => + builder.addField(convertField(field)) + } + .named(field.name) + + case udt: UserDefinedType[_] => + convertField(field.copy(dataType = udt.sqlType)) + + case _ => + throw new IllegalArgumentException( + s"Unsupported data type ${field.dataType.catalogString}") + } + } +} + +private[sql] object GeoParquetSchemaConverter { + def checkFieldName(name: String): Unit = { + // ,;{}()\n\t= and space are special characters in Parquet schema + checkConversionRequirement( + !name.matches(".*[ ,;{}()\n\t=].*"), + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ").trim) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala new file mode 100644 index 0000000000..477d744441 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + +import scala.language.existentials + +object GeoParquetUtils { + def inferSchema( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val parquetOptions = new ParquetOptions(parameters, sparkSession.sessionState.conf) + val shouldMergeSchemas = parquetOptions.mergeSchema + val mergeRespectSummaries = sparkSession.sessionState.conf.isParquetSchemaRespectSummaries + val filesByType = splitFiles(files) + val filesToTouch = + if (shouldMergeSchemas) { + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq.empty + } else { + filesByType.data + } + needMerged ++ filesByType.metadata ++ filesByType.commonMetadata + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + filesByType.commonMetadata.headOption + // Falls back to "_metadata" + .orElse(filesByType.metadata.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(filesByType.data.headOption) + .toSeq + } + GeoParquetFileFormat.mergeSchemasInParallel(parameters, filesToTouch, sparkSession) + } + + case class FileTypes( + data: Seq[FileStatus], + metadata: Seq[FileStatus], + commonMetadata: Seq[FileStatus]) + + private def splitFiles(allFiles: Seq[FileStatus]): FileTypes = { + val leaves = allFiles.toArray.sortBy(_.getPath.toString) + + FileTypes( + data = leaves.filterNot(f => isSummaryFile(f.getPath)), + metadata = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE), + commonMetadata = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)) + } + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + /** + * Legacy mode option is for reading Parquet files written by old versions of Apache Sedona (<= + * 1.3.1-incubating). Such files are actually not GeoParquet files and do not have GeoParquet + * file metadata. Geometry fields were encoded as list of bytes and stored as group type in + * Parquet files. The Definition of GeometryUDT before 1.4.0 was: + * {{{ + * case class GeometryUDT extends UserDefinedType[Geometry] { + * override def sqlType: DataType = ArrayType(ByteType, containsNull = false) + * // ... + * }}} + * Since 1.4.0, the sqlType of GeometryUDT is changed to BinaryType. 
This is a breaking change + * for reading old Parquet files. To read old Parquet files, users need to use "geoparquet" + * format and set legacyMode to true. + * @param parameters + * user provided parameters for reading GeoParquet files using `.option()` method, e.g. + * `spark.read.format("geoparquet").option("legacyMode", "true").load("path")` + * @return + * true if legacyMode is set to true, false otherwise + */ + def isLegacyMode(parameters: Map[String, String]): Boolean = + parameters.getOrElse("legacyMode", "false").toBoolean + + /** + * Parse GeoParquet file metadata from Parquet file metadata. Legacy parquet files do not + * contain GeoParquet file metadata, so we'll simply return an empty GeoParquetMetaData object + * when legacy mode is enabled. + * @param keyValueMetaData + * Parquet file metadata + * @param parameters + * user provided parameters for reading GeoParquet files + * @return + * GeoParquetMetaData object + */ + def parseGeoParquetMetaData( + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): GeoParquetMetaData = { + val isLegacyMode = GeoParquetUtils.isLegacyMode(parameters) + GeoParquetMetaData.parseKeyValueMetaData(keyValueMetaData).getOrElse { + if (isLegacyMode) { + GeoParquetMetaData(None, "", Map.empty) + } else { + throw new IllegalArgumentException("GeoParquet file does not contain valid geo metadata") + } + } + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala new file mode 100644 index 0000000000..90d6d962f4 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala @@ -0,0 +1,628 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.FinalizedWriteContext +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.io.api.Binary +import org.apache.parquet.io.api.RecordConsumer +import org.apache.sedona.common.utils.GeomUtils +import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData.{GEOPARQUET_COVERING_KEY, GEOPARQUET_CRS_KEY, GEOPARQUET_VERSION_KEY, VERSION, createCoveringColumnMetadata} +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetWriteSupport.GeometryColumnInfo +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ +import org.json4s.{DefaultFormats, Extraction, JValue} +import org.json4s.jackson.JsonMethods.parse +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.io.WKBWriter + +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * A Parquet [[WriteSupport]] implementation that writes Catalyst [[InternalRow]]s as Parquet + * messages. This class can write Parquet data in two modes: + * + * - Standard mode: Parquet data are written in standard format defined in parquet-format spec. + * - Legacy mode: Parquet data are written in legacy format compatible with Spark 1.4 and prior. + * + * This behavior can be controlled by SQL option `spark.sql.parquet.writeLegacyFormat`. The value + * of this option is propagated to this class by the `init()` method and its Hadoop configuration + * argument. + */ +class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging { + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to the record consumer. + // Here we are using `SpecializedGetters` rather than `InternalRow` so that we can directly access + // data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // Schema of the `InternalRow`s to be written + private var schema: StructType = _ + + // `ValueWriter`s for all fields of the schema + private var rootFieldWriters: Array[ValueWriter] = _ + + // The Parquet `RecordConsumer` to which all `InternalRow`s are written + private var recordConsumer: RecordConsumer = _ + + // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions + private var writeLegacyParquetFormat: Boolean = _ + + // Which parquet timestamp type to use when writing. 
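+  // (One of INT96, TIMESTAMP_MICROS or TIMESTAMP_MILLIS; `makeWriter` selects the matching
+  // encoder for TimestampType values based on this setting.)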
+ private var outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = _ + + // Reusable byte array used to write timestamps as Parquet INT96 values + private val timestampBuffer = new Array[Byte](12) + + // Reusable byte array used to write decimal values + private val decimalBuffer = + new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) + + private val datetimeRebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_WRITE)) + + private val dateRebaseFunc = + GeoDataSourceUtils.creteDateRebaseFuncInWrite(datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInWrite(datetimeRebaseMode, "Parquet") + + private val int96RebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_WRITE)) + + private val int96RebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInWrite(int96RebaseMode, "Parquet INT96") + + // A mapping from geometry field ordinal to bounding box. According to the geoparquet specification, + // "Geometry columns MUST be at the root of the schema", so we don't need to worry about geometry + // fields in nested structures. + private val geometryColumnInfoMap: mutable.Map[Int, GeometryColumnInfo] = mutable.Map.empty + + private var geoParquetVersion: Option[String] = None + private var defaultGeoParquetCrs: Option[JValue] = None + private val geoParquetColumnCrsMap: mutable.Map[String, Option[JValue]] = mutable.Map.empty + private val geoParquetColumnCoveringMap: mutable.Map[String, Covering] = mutable.Map.empty + + override def init(configuration: Configuration): WriteContext = { + val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) + this.schema = StructType.fromString(schemaString) + this.writeLegacyParquetFormat = { + // `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set in ParquetRelation + assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != null) + configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean + } + + this.outputTimestampType = { + val key = SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key + assert(configuration.get(key) != null) + SQLConf.ParquetOutputTimestampType.withName(configuration.get(key)) + } + + this.rootFieldWriters = schema.zipWithIndex + .map { case (field, ordinal) => + makeWriter(field.dataType, Some(ordinal)) + } + .toArray[ValueWriter] + + if (geometryColumnInfoMap.isEmpty) { + throw new RuntimeException("No geometry column found in the schema") + } + + geoParquetVersion = configuration.get(GEOPARQUET_VERSION_KEY) match { + case null => Some(VERSION) + case version: String => Some(version) + } + defaultGeoParquetCrs = configuration.get(GEOPARQUET_CRS_KEY) match { + case null => + // If no CRS is specified, we write null to the crs metadata field. This is for compatibility with + // geopandas 0.10.0 and earlier versions, which requires crs field to be present. + Some(org.json4s.JNull) + case "" => None + case crs: String => Some(parse(crs)) + } + geometryColumnInfoMap.keys.map(schema(_).name).foreach { name => + Option(configuration.get(GEOPARQUET_CRS_KEY + "." 
+ name)).foreach { + case "" => geoParquetColumnCrsMap.put(name, None) + case crs: String => geoParquetColumnCrsMap.put(name, Some(parse(crs))) + } + } + Option(configuration.get(GEOPARQUET_COVERING_KEY)).foreach { coveringColumnName => + if (geometryColumnInfoMap.size > 1) { + throw new IllegalArgumentException( + s"$GEOPARQUET_COVERING_KEY is ambiguous when there are multiple geometry columns." + + s"Please specify $GEOPARQUET_COVERING_KEY. for configured geometry column.") + } + val geometryColumnName = schema(geometryColumnInfoMap.keys.head).name + val covering = createCoveringColumnMetadata(coveringColumnName, schema) + geoParquetColumnCoveringMap.put(geometryColumnName, covering) + } + geometryColumnInfoMap.keys.map(schema(_).name).foreach { name => + Option(configuration.get(GEOPARQUET_COVERING_KEY + "." + name)).foreach { + coveringColumnName => + val covering = createCoveringColumnMetadata(coveringColumnName, schema) + geoParquetColumnCoveringMap.put(name, covering) + } + } + + val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) + val sparkSqlParquetRowMetadata = GeoParquetWriteSupport.getSparkSqlParquetRowMetadata(schema) + val metadata = Map( + SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT, + ParquetReadSupport.SPARK_METADATA_KEY -> sparkSqlParquetRowMetadata) ++ { + if (datetimeRebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some("org.apache.spark.legacyDateTime" -> "") + } else { + None + } + } ++ { + if (int96RebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some("org.apache.spark.legacyINT96" -> "") + } else { + None + } + } + + logInfo(s"""Initialized Parquet WriteSupport with Catalyst schema: + |${schema.prettyJson} + |and corresponding Parquet message type: + |$messageType + """.stripMargin) + + new WriteContext(messageType, metadata.asJava) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + + override def finalizeWrite(): WriteSupport.FinalizedWriteContext = { + val metadata = new util.HashMap[String, String]() + if (geometryColumnInfoMap.nonEmpty) { + val primaryColumnIndex = geometryColumnInfoMap.keys.head + val primaryColumn = schema.fields(primaryColumnIndex).name + val columns = geometryColumnInfoMap.map { case (ordinal, columnInfo) => + val columnName = schema.fields(ordinal).name + val geometryTypes = columnInfo.seenGeometryTypes.toSeq + val bbox = if (geometryTypes.nonEmpty) { + Seq( + columnInfo.bbox.minX, + columnInfo.bbox.minY, + columnInfo.bbox.maxX, + columnInfo.bbox.maxY) + } else Seq(0.0, 0.0, 0.0, 0.0) + val crs = geoParquetColumnCrsMap.getOrElse(columnName, defaultGeoParquetCrs) + val covering = geoParquetColumnCoveringMap.get(columnName) + columnName -> GeometryFieldMetaData("WKB", geometryTypes, bbox, crs, covering) + }.toMap + val geoParquetMetadata = GeoParquetMetaData(geoParquetVersion, primaryColumn, columns) + val geoParquetMetadataJson = GeoParquetMetaData.toJson(geoParquetMetadata) + metadata.put("geo", geoParquetMetadataJson) + } + new FinalizedWriteContext(metadata) + } + + override def write(row: InternalRow): Unit = { + consumeMessage { + writeFields(row, schema, rootFieldWriters) + } + } + + private def writeFields( + row: InternalRow, + schema: StructType, + fieldWriters: Array[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + if (!row.isNullAt(i)) { + consumeField(schema(i).name, i) { + fieldWriters(i).apply(row, i) + } + } + i += 1 + } + } + + private def makeWriter(dataType: DataType, rootOrdinal: Option[Int] = 
None): ValueWriter = { + dataType match { + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getShort(ordinal)) + + case DateType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(dateRebaseFunc(row.getInt(ordinal))) + + case IntegerType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addFloat(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addDouble(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary( + Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes)) + + case TimestampType => + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + (row: SpecializedGetters, ordinal: Int) => + val micros = int96RebaseFunc(row.getLong(ordinal)) + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(micros) + val buf = ByteBuffer.wrap(timestampBuffer) + buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) + recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + (row: SpecializedGetters, ordinal: Int) => + val micros = row.getLong(ordinal) + recordConsumer.addLong(timestampRebaseFunc(micros)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + (row: SpecializedGetters, ordinal: Int) => + val micros = row.getLong(ordinal) + val millis = GeoDateTimeUtils.microsToMillis(timestampRebaseFunc(micros)) + recordConsumer.addLong(millis) + } + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal))) + + case DecimalType.Fixed(precision, scale) => + makeDecimalWriter(precision, scale) + + case t: StructType => + val fieldWriters = t.map(_.dataType).map(makeWriter(_, None)).toArray[ValueWriter] + (row: SpecializedGetters, ordinal: Int) => + consumeGroup { + writeFields(row.getStruct(ordinal, t.length), t, fieldWriters) + } + + case t: ArrayType => makeArrayWriter(t) + + case t: MapType => makeMapWriter(t) + + case GeometryUDT => + val geometryColumnInfo = rootOrdinal match { + case Some(ordinal) => + geometryColumnInfoMap.getOrElseUpdate(ordinal, new GeometryColumnInfo()) + case None => null + } + (row: SpecializedGetters, ordinal: Int) => { + val serializedGeometry = row.getBinary(ordinal) + val geom = GeometryUDT.deserialize(serializedGeometry) + val wkbWriter = new WKBWriter(GeomUtils.getDimension(geom)) + recordConsumer.addBinary(Binary.fromReusedByteArray(wkbWriter.write(geom))) + if (geometryColumnInfo != null) { + geometryColumnInfo.update(geom) + } + } + + case t: UserDefinedType[_] => makeWriter(t.sqlType) + + // TODO Adds IntervalType support + case _ => sys.error(s"Unsupported data type $dataType.") + } + } + + private def makeDecimalWriter(precision: Int, scale: Int): ValueWriter = { + assert( + precision <= DecimalType.MAX_PRECISION, + s"Decimal precision $precision exceeds max precision 
${DecimalType.MAX_PRECISION}") + + val numBytes = Decimal.minBytesForPrecision(precision) + + val int32Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addInteger(unscaledLong.toInt) + } + + val int64Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addLong(unscaledLong) + } + + val binaryWriterUsingUnscaledLong = + (row: SpecializedGetters, ordinal: Int) => { + // When the precision is low enough (<= 18) to squeeze the decimal value into a `Long`, we + // can build a fixed-length byte array with length `numBytes` using the unscaled `Long` + // value and the `decimalBuffer` for better performance. + val unscaled = row.getDecimal(ordinal, precision, scale).toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + + while (i < numBytes) { + decimalBuffer(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + + recordConsumer.addBinary(Binary.fromReusedByteArray(decimalBuffer, 0, numBytes)) + } + + val binaryWriterUsingUnscaledBytes = + (row: SpecializedGetters, ordinal: Int) => { + val decimal = row.getDecimal(ordinal, precision, scale) + val bytes = decimal.toJavaBigDecimal.unscaledValue().toByteArray + val fixedLengthBytes = if (bytes.length == numBytes) { + // If the length of the underlying byte array of the unscaled `BigInteger` happens to be + // `numBytes`, just reuse it, so that we don't bother copying it to `decimalBuffer`. + bytes + } else { + // Otherwise, the length must be less than `numBytes`. In this case we copy contents of + // the underlying bytes with padding sign bytes to `decimalBuffer` to form the result + // fixed-length byte array. + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + + recordConsumer.addBinary(Binary.fromReusedByteArray(fixedLengthBytes, 0, numBytes)) + } + + writeLegacyParquetFormat match { + // Standard mode, 1 <= precision <= 9, writes as INT32 + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer + + // Standard mode, 10 <= precision <= 18, writes as INT64 + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer + + // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong + + // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY + case _ => binaryWriterUsingUnscaledBytes + } + } + + def makeArrayWriter(arrayType: ArrayType): ValueWriter = { + val elementWriter = makeWriter(arrayType.elementType) + + def threeLevelArrayWriter(repeatedGroupName: String, elementFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < array.numElements()) { + consumeGroup { + // Only creates the element field if the current array element is not null. 
+ if (!array.isNullAt(i)) { + consumeField(elementFieldName, 0) { + elementWriter.apply(array, i) + } + } + } + i += 1 + } + } + } + } + } + + def twoLevelArrayWriter(repeatedFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedFieldName, 0) { + var i = 0 + while (i < array.numElements()) { + elementWriter.apply(array, i) + i += 1 + } + } + } + } + } + + (writeLegacyParquetFormat, arrayType.containsNull) match { + case (legacyMode @ false, _) => + // Standard mode: + // + // group (LIST) { + // repeated group list { + // ^~~~ repeatedGroupName + // element; + // ^~~~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "list", elementFieldName = "element") + + case (legacyMode @ true, nullableElements @ true) => + // Legacy mode, with nullable elements: + // + // group (LIST) { + // optional group bag { + // ^~~ repeatedGroupName + // repeated array; + // ^~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "bag", elementFieldName = "array") + + case (legacyMode @ true, nullableElements @ false) => + // Legacy mode, with non-nullable elements: + // + // group (LIST) { + // repeated array; + // ^~~~~ repeatedFieldName + // } + twoLevelArrayWriter(repeatedFieldName = "array") + } + } + + private def makeMapWriter(mapType: MapType): ValueWriter = { + val keyWriter = makeWriter(mapType.keyType) + val valueWriter = makeWriter(mapType.valueType) + val repeatedGroupName = if (writeLegacyParquetFormat) { + // Legacy mode: + // + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // ^~~ repeatedGroupName + // required key; + // value; + // } + // } + "map" + } else { + // Standard mode: + // + // group (MAP) { + // repeated group key_value { + // ^~~~~~~~~ repeatedGroupName + // required key; + // value; + // } + // } + "key_value" + } + + (row: SpecializedGetters, ordinal: Int) => { + val map = row.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + + consumeGroup { + // Only creates the repeated field if the map is non-empty. + if (map.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < map.numElements()) { + consumeGroup { + consumeField("key", 0) { + keyWriter.apply(keyArray, i) + } + + // Only creates the "value" field if the value if non-empty + if (!map.valueArray().isNullAt(i)) { + consumeField("value", 1) { + valueWriter.apply(valueArray, i) + } + } + } + i += 1 + } + } + } + } + } + } + + private def consumeMessage(f: => Unit): Unit = { + recordConsumer.startMessage() + f + recordConsumer.endMessage() + } + + private def consumeGroup(f: => Unit): Unit = { + recordConsumer.startGroup() + f + recordConsumer.endGroup() + } + + private def consumeField(field: String, index: Int)(f: => Unit): Unit = { + recordConsumer.startField(field, index) + f + recordConsumer.endField(field, index) + } +} + +object GeoParquetWriteSupport { + class GeometryColumnInfo { + val bbox: GeometryColumnBoundingBox = new GeometryColumnBoundingBox() + + // GeoParquet column metadata has a `geometry_types` property, which contains a list of geometry types + // that are present in the column. + val seenGeometryTypes: mutable.Set[String] = mutable.Set.empty + + def update(geom: Geometry): Unit = { + bbox.update(geom) + // In case of 3D geometries, a " Z" suffix gets added (e.g. ["Point Z"]). 
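+      // A geometry is treated as 3D when the Z ordinate of its first coordinate is defined
+      // (non-NaN), which is what the check below computes.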
+ val hasZ = { + val coordinate = geom.getCoordinate + if (coordinate != null) !coordinate.getZ.isNaN else false + } + val geometryType = if (!hasZ) geom.getGeometryType else geom.getGeometryType + " Z" + seenGeometryTypes.add(geometryType) + } + } + + class GeometryColumnBoundingBox( + var minX: Double = Double.PositiveInfinity, + var minY: Double = Double.PositiveInfinity, + var maxX: Double = Double.NegativeInfinity, + var maxY: Double = Double.NegativeInfinity) { + def update(geom: Geometry): Unit = { + val env = geom.getEnvelopeInternal + minX = math.min(minX, env.getMinX) + minY = math.min(minY, env.getMinY) + maxX = math.max(maxX, env.getMaxX) + maxY = math.max(maxY, env.getMaxY) + } + } + + private def getSparkSqlParquetRowMetadata(schema: StructType): String = { + val fields = schema.fields.map { field => + field.dataType match { + case _: GeometryUDT => + // Don't write the GeometryUDT type to the Parquet metadata. Write the type as binary for maximum + // compatibility. + field.copy(dataType = BinaryType) + case _ => field + } + } + StructType(fields).json + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala new file mode 100644 index 0000000000..aadca3a60f --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoSchemaMergeUtils { + + def mergeSchemasInParallel( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus], + schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType]) + : Option[StructType] = { + val serializedConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(parameters)) + + // !! HACK ALERT !! + // Here is a hack for Parquet, but it can be used by Orc as well. + // + // Parquet requires `FileStatus`es to read footers. + // Here we try to send cached `FileStatus`es to executor side to avoid fetching them again. + // However, `FileStatus` is not `Serializable` + // but only `Writable`. 
What makes it worse, for some reason, `FileStatus` doesn't play well + // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These + // facts virtually prevents us to serialize `FileStatus`es. + // + // Since Parquet only relies on path and length information of those `FileStatus`es to read + // footers, here we just extract them (which can be easily serialized), send them to executor + // side, and resemble fake `FileStatus`es there. + val partialFileStatusInfo = files.map(f => (f.getPath.toString, f.getLen)) + + // Set the number of partitions to prevent following schema reads from generating many tasks + // in case of a small number of orc files. + val numParallelism = Math.min( + Math.max(partialFileStatusInfo.size, 1), + sparkSession.sparkContext.defaultParallelism) + + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + + // Issues a Spark job to read Parquet/ORC schema in parallel. + val partiallyMergedSchemas = + sparkSession.sparkContext + .parallelize(partialFileStatusInfo, numParallelism) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + val schemas = schemaReader(fakeFileStatuses, serializedConf.value, ignoreCorruptFiles) + + if (schemas.isEmpty) { + Iterator.empty + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergedSchema.merge(schema) + } catch { + case cause: SparkException => + throw new SparkException(s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } + } + .collect() + + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { + case cause: SparkException => + throw new SparkException(s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala new file mode 100644 index 0000000000..43e1ababb7 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Data source for reading GeoParquet metadata. This could be accessed using the `spark.read` + * interface: + * {{{ + * val df = spark.read.format("geoparquet.metadata").load("path/to/geoparquet") + * }}} + */ +class GeoParquetMetadataDataSource extends FileDataSourceV2 with DataSourceRegister { + override val shortName: String = "geoparquet.metadata" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + None, + fallbackFileFormat) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala new file mode 100644 index 0000000000..1fe2faa2e0 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods.{compact, render} + +case class GeoParquetMetadataPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + filters: Seq[Filter]) + extends FilePartitionReaderFactory { + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { + val iter = GeoParquetMetadataPartitionReaderFactory.readFile( + broadcastedConf.value.value, + partitionedFile, + readDataSchema) + val fileReader = new PartitionReaderFromIterator[InternalRow](iter) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFile.partitionValues) + } +} + +object GeoParquetMetadataPartitionReaderFactory { + private def readFile( + configuration: Configuration, + partitionedFile: PartitionedFile, + readDataSchema: StructType): Iterator[InternalRow] = { + val filePath = partitionedFile.filePath + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath), configuration)) + .getFooter + .getFileMetaData + .getKeyValueMetaData + val row = GeoParquetMetaData.parseKeyValueMetaData(metadata) match { + case Some(geo) => + val geoColumnsMap = geo.columns.map { case (columnName, columnMetadata) => + implicit val formats: org.json4s.Formats = DefaultFormats + import org.json4s.jackson.Serialization + val columnMetadataFields: Array[Any] = Array( + UTF8String.fromString(columnMetadata.encoding), + new GenericArrayData(columnMetadata.geometryTypes.map(UTF8String.fromString).toArray), + new GenericArrayData(columnMetadata.bbox.toArray), + columnMetadata.crs + .map(projjson => UTF8String.fromString(compact(render(projjson)))) + .getOrElse(UTF8String.fromString("")), + columnMetadata.covering + .map(covering => UTF8String.fromString(Serialization.write(covering))) + .orNull) + val columnMetadataStruct = new GenericInternalRow(columnMetadataFields) + UTF8String.fromString(columnName) -> columnMetadataStruct + } + val fields: Array[Any] = Array( + UTF8String.fromString(filePath), + UTF8String.fromString(geo.version.orNull), + UTF8String.fromString(geo.primaryColumn), + ArrayBasedMapData(geoColumnsMap)) + new GenericInternalRow(fields) + case None => + // Not a GeoParquet file, return a row with null metadata values. 
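+        // Only the file path is populated; the version, primary column and per-column
+        // metadata values are all null.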
+ val fields: Array[Any] = Array(UTF8String.fromString(filePath), null, null, null) + new GenericInternalRow(fields) + } + Iterator(pruneBySchema(row, GeoParquetMetadataTable.schema, readDataSchema)) + } + + private def pruneBySchema( + row: InternalRow, + schema: StructType, + readDataSchema: StructType): InternalRow = { + // Projection push down for nested fields is not enabled, so this very simple implementation is enough. + val values: Array[Any] = readDataSchema.fields.map { field => + val index = schema.fieldIndex(field.name) + row.get(index, field.dataType) + } + new GenericInternalRow(values) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala new file mode 100644 index 0000000000..b86ab7a399 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ + +case class GeoParquetMetadataScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. + // We should use `readPartitionSchema` as the partition schema here. 
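+ // Note that the pushed filters are forwarded for completeness only; the reader + // factory does not evaluate them when reading file footers.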
+ GeoParquetMetadataPartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + pushedFilters) + } + + override def getFileUnSplittableReason(path: Path): String = + "Reading parquet file metadata does not require splitting the file" + + // This is for compatibility with Spark 3.0. Spark 3.3 does not have this method + def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = { + copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala new file mode 100644 index 0000000000..6a25e4530c --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class GeoParquetMetadataScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + override def build(): Scan = { + GeoParquetMetadataScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + getPushedDataFilters, + getPartitionFilters, + getDataFilters) + } + + // The following methods uses reflection to address compatibility issues for Spark 3.0 ~ 3.2 + + private def getPushedDataFilters: Array[Filter] = { + try { + val field = classOf[FileScanBuilder].getDeclaredField("pushedDataFilters") + field.setAccessible(true) + field.get(this).asInstanceOf[Array[Filter]] + } catch { + case _: NoSuchFieldException => + Array.empty + } + } + + private def getPartitionFilters: Seq[Expression] = { + try { + val field = classOf[FileScanBuilder].getDeclaredField("partitionFilters") + field.setAccessible(true) + field.get(this).asInstanceOf[Seq[Expression]] + } catch { + case _: NoSuchFieldException => + Seq.empty + } + } + + private def getDataFilters: Seq[Expression] = { + try { + val field = classOf[FileScanBuilder].getDeclaredField("dataFilters") + field.setAccessible(true) + field.get(this).asInstanceOf[Seq[Expression]] + } catch { + case _: NoSuchFieldException => + Seq.empty + } + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala new file mode 100644 index 0000000000..845764fae5 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class GeoParquetMetadataTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + override def formatName: String = "GeoParquet Metadata" + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + Some(GeoParquetMetadataTable.schema) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new GeoParquetMetadataScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) +} + +object GeoParquetMetadataTable { + private val columnMetadataType = StructType( + Seq( + StructField("encoding", StringType, nullable = true), + StructField("geometry_types", ArrayType(StringType), nullable = true), + StructField("bbox", ArrayType(DoubleType), nullable = true), + StructField("crs", StringType, nullable = true), + StructField("covering", StringType, nullable = true))) + + private val columnsType = MapType(StringType, columnMetadataType, valueContainsNull = false) + + val schema: StructType = StructType( + Seq( + StructField("path", StringType, nullable = false), + StructField("version", StringType, nullable = true), + StructField("primary_column", StringType, nullable = true), + StructField("columns", columnsType, nullable = true))) +} diff --git a/spark/spark-3.2/src/test/resources/log4j2.properties b/spark/spark-3.2/src/test/resources/log4j2.properties new file mode 100644 index 0000000000..5f89859463 --- /dev/null +++ b/spark/spark-3.2/src/test/resources/log4j2.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file target/unit-tests.log +rootLogger.level = info +rootLogger.appenderRef.file.ref = File + +appender.file.type = File +appender.file.name = File +appender.file.fileName = target/unit-tests.log +appender.file.append = true +appender.file.layout.type = PatternLayout +appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Ignore messages below warning level from Jetty, because it's a bit verbose +logger.jetty.name = org.sparkproject.jetty +logger.jetty.level = warn diff --git a/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala new file mode 100644 index 0000000000..421890c700 --- /dev/null +++ b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.spark.sql.Row +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.scalatest.BeforeAndAfterAll + +import java.util.Collections +import scala.collection.JavaConverters._ + +class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation: String = resourceFolder + "geoparquet/" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + describe("GeoParquet Metadata tests") { + it("Reading GeoParquet Metadata") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.collect() + assert(metadataArray.length > 1) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("version") == "1.0.0-dev")) + assert(metadataArray.exists(_.getAs[String]("primary_column") == "geometry")) + assert(metadataArray.exists { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + columnsMap != null && columnsMap + .containsKey("geometry") && columnsMap.get("geometry").isInstanceOf[Row] + }) + assert(metadataArray.forall { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + if (columnsMap == null || !columnsMap.containsKey("geometry")) true + else { + val columnMetadata = columnsMap.get("geometry").asInstanceOf[Row] + columnMetadata.getAs[String]("encoding") == "WKB" && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("bbox")) + .asScala + .forall(_.isInstanceOf[Double]) && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("geometry_types")) + .asScala + .forall(_.isInstanceOf[String]) && + columnMetadata.getAs[String]("crs").nonEmpty && + 
columnMetadata.getAs[String]("crs") != "null" + } + }) + } + + it("Reading GeoParquet Metadata with column pruning") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df + .selectExpr("path", "substring(primary_column, 1, 2) AS partial_primary_column") + .collect() + assert(metadataArray.length > 1) + assert(metadataArray.forall(_.length == 2)) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("partial_primary_column") == "ge")) + } + + it("Reading GeoParquet Metadata of plain parquet files") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.where("path LIKE '%plain.parquet'").collect() + assert(metadataArray.nonEmpty) + assert(metadataArray.forall(_.getAs[String]("path").endsWith("plain.parquet"))) + assert(metadataArray.forall(_.getAs[String]("version") == null)) + assert(metadataArray.forall(_.getAs[String]("primary_column") == null)) + assert(metadataArray.forall(_.getAs[String]("columns") == null)) + } + + it("Read GeoParquet without CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "") + } + + it("Read GeoParquet with null CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "null") + } + + it("Read GeoParquet with snake_case geometry column name and camelCase column name") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column_1", GeometryUDT, nullable = false), + StructField("geomColumn2", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")) + assert(metadata.containsKey("geom_column_1")) + assert(!metadata.containsKey("geoColumn1")) + assert(metadata.containsKey("geomColumn2")) + assert(!metadata.containsKey("geom_column2")) + assert(!metadata.containsKey("geom_column_2")) + } + + it("Read GeoParquet with covering metadata") { + val dfMeta = sparkSession.read + .format("geoparquet.metadata") + .load(geoparquetdatalocation + "/example-1.1.0.parquet") + val row = dfMeta.collect()(0) + val metadata = 
row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + val covering = metadata.getAs[String]("covering") + assert(covering.nonEmpty) + Seq("bbox", "xmin", "ymin", "xmax", "ymax").foreach { key => + assert(covering contains key) + } + } + } +} diff --git a/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala new file mode 100644 index 0000000000..8f3cc3f1e5 --- /dev/null +++ b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.generateTestData +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.readGeoParquetMetaDataMap +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.writeTestDataAsGeoParquet +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter +import org.locationtech.jts.geom.Coordinate +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.geom.GeometryFactory +import org.scalatest.prop.TableDrivenPropertyChecks + +import java.io.File +import java.nio.file.Files + +class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrivenPropertyChecks { + + val tempDir: String = + Files.createTempDirectory("sedona_geoparquet_test_").toFile.getAbsolutePath + val geoParquetDir: String = tempDir + "/geoparquet" + var df: DataFrame = _ + var geoParquetDf: DataFrame = _ + var geoParquetMetaDataMap: Map[Int, Seq[GeoParquetMetaData]] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + df = generateTestData(sparkSession) + writeTestDataAsGeoParquet(df, geoParquetDir) + geoParquetDf = sparkSession.read.format("geoparquet").load(geoParquetDir) + geoParquetMetaDataMap = readGeoParquetMetaDataMap(geoParquetDir) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir)) + + describe("GeoParquet spatial filter push down tests") { + it("Push down ST_Contains") { + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", + Seq.empty) + testFilter("ST_Contains(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + it("Push down ST_Covers") { + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", + Seq.empty) + testFilter("ST_Covers(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Covers(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Covers(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + it("Push down ST_Within") { + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", + Seq(1)) + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", + Seq(0)) + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_Within(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3)) + testFilter( + "ST_Within(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", + Seq(3)) + testFilter( + "ST_Within(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", + Seq.empty) + } + + it("Push down ST_CoveredBy") { + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", + Seq(1)) + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", + Seq(0)) + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_CoveredBy(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3)) + testFilter( + "ST_CoveredBy(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", + Seq(3)) + testFilter( + "ST_CoveredBy(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", + Seq.empty) + } + + it("Push down ST_Intersects") { + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_Intersects(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", + Seq(1, 3)) + } + + it("Push down ST_Equals") { + testFilter( + "ST_Equals(geom, ST_GeomFromText('POLYGON ((-16 -16, -16 -14, -14 -14, -14 -16, -16 -16))'))", + Seq(2)) + testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-15 -15)'))", Seq(2)) + 
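// POINT (-16 -16) sits on the corner of region 2's bounding box, so the file-level filter + // cannot prune that file, even though the exact-equality query itself matches no geometry. +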
testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-16 -16)'))", Seq(2)) + testFilter( + "ST_Equals(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + forAll(Table("<", "<=")) { op => + it(s"Push down ST_Distance $op d") { + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 1", Seq.empty) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 5", Seq.empty) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (3 4)')) $op 1", Seq(1)) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 7.1", Seq(0, 1, 2, 3)) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) $op 1", Seq(2)) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 2", + Seq.empty) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 3", + Seq(0, 1, 2, 3)) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('LINESTRING (17 17, 18 18)')) $op 1", + Seq(1)) + } + } + + it("Push down And(filters...)") { + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + } + + it("Push down Or(filters...)") { + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom) OR ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0, 1)) + testFilter( + "ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) <= 1 OR ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1, 2)) + } + + it("Ignore negated spatial filters") { + testFilter( + "NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(0, 1, 2, 3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) AND NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) OR NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(0, 1, 2, 3)) + } + + it("Mixed spatial filter with other filter") { + testFilter( + "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", + Seq(1, 3)) + } + } + + /** + * Test filter push down using specified query condition, and verify if the pushed down filter + * prunes regions as expected. We'll also verify the correctness of query results. 
+ * @param condition + * SQL query condition + * @param expectedPreservedRegions + * Regions that should be preserved after filter push down + */ + private def testFilter(condition: String, expectedPreservedRegions: Seq[Int]): Unit = { + val dfFiltered = geoParquetDf.where(condition) + val preservedRegions = getPushedDownSpatialFilter(dfFiltered) match { + case Some(spatialFilter) => resolvePreservedRegions(spatialFilter) + case None => (0 until 4) + } + assert(expectedPreservedRegions == preservedRegions) + val expectedResult = + df.where(condition).orderBy("region", "id").select("region", "id").collect() + val actualResult = dfFiltered.orderBy("region", "id").select("region", "id").collect() + assert(expectedResult sameElements actualResult) + } + + private def getPushedDownSpatialFilter(df: DataFrame): Option[GeoParquetSpatialFilter] = { + val executedPlan = df.queryExecution.executedPlan + val fileSourceScanExec = executedPlan.find(_.isInstanceOf[FileSourceScanExec]) + assert(fileSourceScanExec.isDefined) + val fileFormat = fileSourceScanExec.get.asInstanceOf[FileSourceScanExec].relation.fileFormat + assert(fileFormat.isInstanceOf[GeoParquetFileFormat]) + fileFormat.asInstanceOf[GeoParquetFileFormat].spatialFilter + } + + private def resolvePreservedRegions(spatialFilter: GeoParquetSpatialFilter): Seq[Int] = { + geoParquetMetaDataMap + .filter { case (_, metaDataList) => + metaDataList.exists(metadata => spatialFilter.evaluate(metadata.columns)) + } + .keys + .toSeq + } +} + +object GeoParquetSpatialFilterPushDownSuite { + case class TestDataItem(id: Int, region: Int, geom: Geometry) + + /** + * Generate test data centered at (0, 0). The entire dataset was divided into 4 quadrants, each + * with a unique region ID. The dataset contains 4 points and 4 polygons in each quadrant. + * @param sparkSession + * SparkSession object + * @return + * DataFrame containing test data + */ + def generateTestData(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + val regionCenters = Seq((-10, 10), (10, 10), (-10, -10), (10, -10)) + val testData = regionCenters.zipWithIndex.flatMap { case ((x, y), i) => + generateTestDataForRegion(i, x, y) + } + testData.toDF() + } + + private def generateTestDataForRegion(region: Int, centerX: Double, centerY: Double) = { + val factory = new GeometryFactory() + val points = Seq( + factory.createPoint(new Coordinate(centerX - 5, centerY + 5)), + factory.createPoint(new Coordinate(centerX + 5, centerY + 5)), + factory.createPoint(new Coordinate(centerX - 5, centerY - 5)), + factory.createPoint(new Coordinate(centerX + 5, centerY - 5))) + val polygons = points.map { p => + val envelope = p.getEnvelopeInternal + envelope.expandBy(1) + factory.toGeometry(envelope) + } + (points ++ polygons).zipWithIndex.map { case (g, i) => TestDataItem(i, region, g) } + } + + /** + * Write the test dataframe as GeoParquet files. Each region is written to a separate file. + * We'll test spatial filter push down by examining which regions were preserved/pruned by + * evaluating the pushed down spatial filters + * @param testData + * dataframe containing test data + * @param path + * path to write GeoParquet files + */ + def writeTestDataAsGeoParquet(testData: DataFrame, path: String): Unit = { + testData.coalesce(1).write.partitionBy("region").format("geoparquet").save(path) + } + + /** + * Load GeoParquet metadata for each region. Note that there could be multiple files for each + * region, thus each region ID was associated with a list of GeoParquet metadata. 
+ * @param path + * path to directory containing GeoParquet files + * @return + * Map of region ID to list of GeoParquet metadata + */ + def readGeoParquetMetaDataMap(path: String): Map[Int, Seq[GeoParquetMetaData]] = { + (0 until 4).map { k => + val geoParquetMetaDataSeq = readGeoParquetMetaDataByRegion(path, k) + k -> geoParquetMetaDataSeq + }.toMap + } + + private def readGeoParquetMetaDataByRegion( + geoParquetSavePath: String, + region: Int): Seq[GeoParquetMetaData] = { + val parquetFiles = new File(geoParquetSavePath + s"/region=$region") + .listFiles() + .filter(_.getName.endsWith(".parquet")) + parquetFiles.flatMap { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + GeoParquetMetaData.parseKeyValueMetaData(metadata) + } + } +} diff --git a/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala new file mode 100644 index 0000000000..2da12eceb0 --- /dev/null +++ b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.log4j.{Level, Logger} +import org.apache.sedona.spark.SedonaContext +import org.apache.spark.sql.DataFrame +import org.scalatest.{BeforeAndAfterAll, FunSpec} + +trait TestBaseScala extends FunSpec with BeforeAndAfterAll { + Logger.getRootLogger().setLevel(Level.WARN) + Logger.getLogger("org.apache").setLevel(Level.WARN) + Logger.getLogger("com").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) + + val warehouseLocation = System.getProperty("user.dir") + "/target/" + val sparkSession = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + // We need to be explicit about broadcasting in tests. 
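+ // A threshold of -1 disables automatic broadcast joins.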
+ .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .getOrCreate() + + val resourceFolder = System.getProperty("user.dir") + "/../common/src/test/resources/" + + override def beforeAll(): Unit = { + SedonaContext.create(sparkSession) + } + + override def afterAll(): Unit = { + // SedonaSQLRegistrator.dropAll(spark) + // spark.stop + } + + def loadCsv(path: String): DataFrame = { + sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path) + } +} diff --git a/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala new file mode 100644 index 0000000000..ccfd560c84 --- /dev/null +++ b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala @@ -0,0 +1,748 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.Row +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.execution.datasources.parquet.{Covering, GeoParquetMetaData, ParquetReadSupport} +import org.apache.spark.sql.functions.{col, expr} +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.sedona_sql.expressions.st_constructors.{ST_Point, ST_PolygonFromEnvelope} +import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.json4s.jackson.parseJson +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.io.WKTReader +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.util.Collections +import java.util.concurrent.atomic.AtomicLong +import scala.collection.JavaConverters._ + +class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation1: String = resourceFolder + "geoparquet/example1.parquet" + val geoparquetdatalocation2: String = resourceFolder + "geoparquet/example2.parquet" + val geoparquetdatalocation3: String = resourceFolder + "geoparquet/example3.parquet" + val geoparquetdatalocation4: String = resourceFolder + "geoparquet/example-1.0.0-beta.1.parquet" + val geoparquetdatalocation5: String = resourceFolder + "geoparquet/example-1.1.0.parquet" + val legacyparquetdatalocation: String = + resourceFolder + 
"parquet/legacy-parquet-nested-columns.snappy.parquet" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(geoparquetoutputlocation)) + + describe("GeoParquet IO tests") { + it("GEOPARQUET Test example1 i.e. naturalearth_lowers dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1) + val rows = df.collect()(0) + assert(rows.getAs[Long]("pop_est") == 920938) + assert(rows.getAs[String]("continent") == "Oceania") + assert(rows.getAs[String]("name") == "Fiji") + assert(rows.getAs[String]("iso_a3") == "FJI") + assert(rows.getAs[Double]("gdp_md_est") == 8374.0) + assert( + rows + .getAs[Geometry]("geometry") + .toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177.28504 -17.72465, 177.67087 -17.381140000000002, 178.12557 -17.50481)), ((-179.79332010904864 -16.020882256741224, -179.9173693847653 -16.501783135649397, -180 -16.555216566639196, -180 -16.067132663642447, -179.79332010904864 -16.020882256741224)))") + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample1.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample1.parquet") + val newrows = df2.collect()(0) + assert( + newrows + .getAs[Geometry]("geometry") + .toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177.28504 -17.72465, 177.67087 -17.381140000000002, 178.12557 -17.50481)), ((-179.79332010904864 -16.020882256741224, -179.9173693847653 -16.501783135649397, -180 -16.555216566639196, -180 -16.067132663642447, -179.79332010904864 -16.020882256741224)))") + } + it("GEOPARQUET Test example2 i.e. naturalearth_citie dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation2) + val rows = df.collect()(0) + assert(rows.getAs[String]("name") == "Vatican City") + assert( + rows + .getAs[Geometry]("geometry") + .toString == "POINT (12.453386544971766 41.903282179960115)") + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample2.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample2.parquet") + val newrows = df2.collect()(0) + assert(newrows.getAs[String]("name") == "Vatican City") + assert( + newrows + .getAs[Geometry]("geometry") + .toString == "POINT (12.453386544971766 41.903282179960115)") + } + it("GEOPARQUET Test example3 i.e. 
nybb dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation3) + val rows = df.collect()(0) + assert(rows.getAs[Long]("BoroCode") == 5) + assert(rows.getAs[String]("BoroName") == "Staten Island") + assert(rows.getAs[Double]("Shape_Leng") == 330470.010332) + assert(rows.getAs[Double]("Shape_Area") == 1.62381982381e9) + assert(rows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022")) + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample3.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample3.parquet") + val newrows = df2.collect()(0) + assert( + newrows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022")) + } + it("GEOPARQUET Test example-1.0.0-beta.1.parquet") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation4) + val count = df.count() + val rows = df.collect() + assert(rows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(count == rows.length) + + val geoParquetSavePath = geoparquetoutputlocation + "/gp_sample4.parquet" + df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val newRows = df2.collect() + assert(rows.length == newRows.length) + assert(newRows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(rows sameElements newRows) + + val parquetFiles = + new File(geoParquetSavePath).listFiles().filter(_.getName.endsWith(".parquet")) + parquetFiles.foreach { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + val geo = parseJson(metadata.get("geo")) + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + assert(columnName == "geometry") + val geomTypes = (geo \ "columns" \ "geometry" \ "geometry_types").extract[Seq[String]] + assert(geomTypes.nonEmpty) + val sparkSqlRowMetadata = metadata.get(ParquetReadSupport.SPARK_METADATA_KEY) + assert(!sparkSqlRowMetadata.contains("GeometryUDT")) + } + } + it("GEOPARQUET Test example-1.1.0.parquet") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation5) + val count = df.count() + val rows = df.collect() + assert(rows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(count == rows.length) + + val geoParquetSavePath = geoparquetoutputlocation + "/gp_sample5.parquet" + df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val newRows = df2.collect() + assert(rows.length == newRows.length) + assert(newRows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(rows sameElements newRows) + } + + it("GeoParquet with multiple geometry columns") { + val wktReader = new WKTReader() + val testData = Seq( + Row( + 1, + wktReader.read("POINT (1 2)"), + wktReader.read("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))")), + Row( + 2, + wktReader.read("POINT Z(1 2 3)"), + wktReader.read("POLYGON Z((0 0 2, 1 0 2, 1 1 2, 0 1 2, 0 0 2))")), + Row( + 3, + wktReader.read("MULTIPOINT (0 0, 1 1, 2 2)"), + wktReader.read("MULTILINESTRING ((0 0, 1 1), (2 2, 3 3))"))) + val schema = StructType( + Seq( + StructField("id", 
IntegerType, nullable = false), + StructField("g0", GeometryUDT, nullable = false), + StructField("g1", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) + val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + // Find parquet files in geoParquetSavePath directory and validate their metadata + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val version = (geo \ "version").extract[String] + assert(version == GeoParquetMetaData.VERSION) + val g0Types = (geo \ "columns" \ "g0" \ "geometry_types").extract[Seq[String]] + val g1Types = (geo \ "columns" \ "g1" \ "geometry_types").extract[Seq[String]] + assert(g0Types.sorted == Seq("Point", "Point Z", "MultiPoint").sorted) + assert(g1Types.sorted == Seq("Polygon", "Polygon Z", "MultiLineString").sorted) + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNull) + assert(g1Crs == org.json4s.JNull) + } + + // Read GeoParquet with multiple geometry columns + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.schema.fields(1).dataType.isInstanceOf[GeometryUDT]) + assert(df2.schema.fields(2).dataType.isInstanceOf[GeometryUDT]) + val rows = df2.collect() + assert(testData.length == rows.length) + assert(rows(0).getAs[AnyRef]("g0").isInstanceOf[Geometry]) + assert(rows(0).getAs[AnyRef]("g1").isInstanceOf[Geometry]) + } + + it("GeoParquet save should work with empty dataframes") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("g", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/empty.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.schema.fields(1).dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val g0Types = (geo \ "columns" \ "g" \ "geometry_types").extract[Seq[String]] + val g0BBox = (geo \ "columns" \ "g" \ "bbox").extract[Seq[Double]] + assert(g0Types.isEmpty) + assert(g0BBox == Seq(0.0, 0.0, 0.0, 0.0)) + } + } + + it("GeoParquet save should work with snake_case column names") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/snake_case_column_name.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val geomField = df2.schema.fields(1) + assert(geomField.name == "geom_column") + assert(geomField.dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + } + + it("GeoParquet save should work with camelCase column names") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geomColumn", GeometryUDT, nullable = false))) + val df = 
sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/camel_case_column_name.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val geomField = df2.schema.fields(1) + assert(geomField.name == "geomColumn") + assert(geomField.dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + } + + it("GeoParquet save should write user specified version and crs to geo metadata") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation4) + // This CRS is taken from https://proj.org/en/9.3/specifications/projjson.html#geographiccrs + // with slight modification. + val projjson = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "NAD83(2011)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "NAD83 (National Spatial Reference System 2011)", + | "ellipsoid": { + | "name": "GRS 1980", + | "semi_major_axis": 6378137, + | "inverse_flattening": 298.257222101 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Horizontal component of 3D system.", + | "area": "Puerto Rico - onshore and offshore. United States (USA) onshore and offshore.", + | "bbox": { + | "south_latitude": 14.92, + | "west_longitude": 167.65, + | "north_latitude": 74.71, + | "east_longitude": -63.88 + | }, + | "id": { + | "authority": "EPSG", + | "code": 6318 + | } + |} + |""".stripMargin + var geoParquetSavePath = geoparquetoutputlocation + "/gp_custom_meta.parquet" + df.write + .format("geoparquet") + .option("geoparquet.version", "10.9.8") + .option("geoparquet.crs", projjson) + .mode("overwrite") + .save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.count() == df.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val version = (geo \ "version").extract[String] + val columnName = (geo \ "primary_column").extract[String] + assert(version == "10.9.8") + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs.isInstanceOf[org.json4s.JObject]) + assert(crs == parseJson(projjson)) + } + + // Setting crs to null explicitly + geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val df3 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df3.count() == df.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs == org.json4s.JNull) + } + + // Setting crs to "" to omit crs + geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: 
org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs == org.json4s.JNothing) + } + } + + it("GeoParquet save should support specifying per-column CRS") { + val wktReader = new WKTReader() + val testData = Seq( + Row( + 1, + wktReader.read("POINT (1 2)"), + wktReader.read("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"))) + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("g0", GeometryUDT, nullable = false), + StructField("g1", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) + + val projjson0 = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "NAD83(2011)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "NAD83 (National Spatial Reference System 2011)", + | "ellipsoid": { + | "name": "GRS 1980", + | "semi_major_axis": 6378137, + | "inverse_flattening": 298.257222101 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Horizontal component of 3D system.", + | "area": "Puerto Rico - onshore and offshore. United States (USA) onshore and offshore.", + | "bbox": { + | "south_latitude": 14.92, + | "west_longitude": 167.65, + | "north_latitude": 74.71, + | "east_longitude": -63.88 + | }, + | "id": { + | "authority": "EPSG", + | "code": 6318 + | } + |} + |""".stripMargin + + val projjson1 = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "Monte Mario (Rome)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "Monte Mario (Rome)", + | "ellipsoid": { + | "name": "International 1924", + | "semi_major_axis": 6378388, + | "inverse_flattening": 297 + | }, + | "prime_meridian": { + | "name": "Rome", + | "longitude": 12.4523333333333 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Geodesy, onshore minerals management.", + | "area": "Italy - onshore and offshore; San Marino, Vatican City State.", + | "bbox": { + | "south_latitude": 34.76, + | "west_longitude": 5.93, + | "north_latitude": 47.1, + | "east_longitude": 18.99 + | }, + | "id": { + | "authority": "EPSG", + | "code": 4806 + | } + |} + |""".stripMargin + + val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms_with_custom_crs.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == parseJson(projjson1)) + } + + // Write without fallback CRS for g0 + df.write + .format("geoparquet") + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") 
+ .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNull) + assert(g1Crs == parseJson(projjson1)) + } + + // Fallback CRS is omitting CRS + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNothing) + assert(g1Crs == parseJson(projjson1)) + } + + // Write with CRS, explicitly set CRS to null for g1 + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", "null") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == org.json4s.JNull) + } + + // Write with CRS, explicitly omit CRS for g1 + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", "") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == org.json4s.JNothing) + } + } + + it("GeoParquet load should raise exception when loading plain parquet files") { + val e = intercept[SparkException] { + sparkSession.read.format("geoparquet").load(resourceFolder + "geoparquet/plain.parquet") + } + assert(e.getMessage.contains("does not contain valid geo metadata")) + } + + it("GeoParquet load with spatial predicates") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1) + val rows = + df.where(ST_Intersects(ST_Point(35.174722, -6.552465), col("geometry"))).collect() + assert(rows.length == 1) + assert(rows(0).getAs[String]("name") == "Tanzania") + } + + it("Filter push down for nested columns") { + import sparkSession.implicits._ + + // Prepare multiple GeoParquet files with bbox metadata. There should be 10 files in total, each file contains + // 1000 records. + val dfIds = (0 until 10000).toDF("id") + val dfGeom = dfIds + .withColumn( + "bbox", + expr("struct(id as minx, id as miny, id + 1 as maxx, id + 1 as maxy)")) + .withColumn("geom", expr("ST_PolygonFromEnvelope(id, id, id + 1, id + 1)")) + .withColumn("part_id", expr("CAST(id / 1000 AS INTEGER)")) + .coalesce(1) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_bbox.parquet" + dfGeom.write + .partitionBy("part_id") + .format("geoparquet") + .mode("overwrite") + .save(geoParquetSavePath) + + val sparkListener = new SparkListener() { + val recordsRead = new AtomicLong(0) + + def reset(): Unit = recordsRead.set(0) + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead + this.recordsRead.getAndAdd(recordsRead) + } + } + + sparkSession.sparkContext.addSparkListener(sparkListener) + try { + val df = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + + // This should trigger filter push down to Parquet and only read one of the files. The number of records read + // should be less than 1000. 
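+ // With minx = id, this predicate matches ids 6001 through 6599, which all fall in the part_id=6 partition.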
+ df.where("bbox.minx > 6000 and bbox.minx < 6600").count() + assert(sparkListener.recordsRead.get() <= 1000) + + // Reading these files using spatial filter. This should only read two of the files. + sparkListener.reset() + df.where(ST_Intersects(ST_PolygonFromEnvelope(7010, 7010, 8100, 8100), col("geom"))) + .count() + assert(sparkListener.recordsRead.get() <= 2000) + } finally { + sparkSession.sparkContext.removeSparkListener(sparkListener) + } + } + + it("Ready legacy parquet files written by Apache Sedona <= 1.3.1-incubating") { + val df = sparkSession.read + .format("geoparquet") + .option("legacyMode", "true") + .load(legacyparquetdatalocation) + val rows = df.collect() + assert(rows.nonEmpty) + rows.foreach { row => + assert(row.getAs[AnyRef]("geom").isInstanceOf[Geometry]) + assert(row.getAs[AnyRef]("struct_geom").isInstanceOf[Row]) + val structGeom = row.getAs[Row]("struct_geom") + assert(structGeom.getAs[AnyRef]("g0").isInstanceOf[Geometry]) + assert(structGeom.getAs[AnyRef]("g1").isInstanceOf[Geometry]) + } + } + + it("GeoParquet supports writing covering metadata") { + val df = sparkSession + .range(0, 100) + .toDF("id") + .withColumn("id", expr("CAST(id AS DOUBLE)")) + .withColumn("geometry", expr("ST_Point(id, id + 1)")) + .withColumn( + "test_cov", + expr("struct(id AS xmin, id + 1 AS ymin, id AS xmax, id + 1 AS ymax)")) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_covering_metadata.parquet" + df.write + .format("geoparquet") + .option("geoparquet.covering", "test_cov") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val coveringJsValue = geo \ "columns" \ "geometry" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov", "ymax")) + } + + df.write + .format("geoparquet") + .option("geoparquet.covering.geometry", "test_cov") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val coveringJsValue = geo \ "columns" \ "geometry" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov", "ymax")) + } + } + + it("GeoParquet supports writing covering metadata for multiple columns") { + val df = sparkSession + .range(0, 100) + .toDF("id") + .withColumn("id", expr("CAST(id AS DOUBLE)")) + .withColumn("geom1", expr("ST_Point(id, id + 1)")) + .withColumn( + "test_cov1", + expr("struct(id AS xmin, id + 1 AS ymin, id AS xmax, id + 1 AS ymax)")) + .withColumn("geom2", expr("ST_Point(10 * id, 10 * id + 1)")) + .withColumn( + "test_cov2", + expr( + "struct(10 * id AS xmin, 10 * id + 1 AS ymin, 10 * id AS xmax, 10 * id + 1 AS ymax)")) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_covering_metadata.parquet" + df.write + .format("geoparquet") + .option("geoparquet.covering.geom1", "test_cov1") + .option("geoparquet.covering.geom2", "test_cov2") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val 
formats: org.json4s.Formats = org.json4s.DefaultFormats + Seq(("geom1", "test_cov1"), ("geom2", "test_cov2")).foreach { + case (geomName, coveringName) => + val coveringJsValue = geo \ "columns" \ geomName \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq(coveringName, "xmin")) + assert(covering.bbox.ymin == Seq(coveringName, "ymin")) + assert(covering.bbox.xmax == Seq(coveringName, "xmax")) + assert(covering.bbox.ymax == Seq(coveringName, "ymax")) + } + } + + df.write + .format("geoparquet") + .option("geoparquet.covering.geom2", "test_cov2") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + assert(geo \ "columns" \ "geom1" \ "covering" == org.json4s.JNothing) + val coveringJsValue = geo \ "columns" \ "geom2" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov2", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov2", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov2", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov2", "ymax")) + } + } + } + + def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = { + val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet")) + parquetFiles.foreach { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + val geo = parseJson(metadata.get("geo")) + body(geo) + } + } +} diff --git a/spark/spark-3.3/.gitignore b/spark/spark-3.3/.gitignore new file mode 100644 index 0000000000..1cc6c4a1f6 --- /dev/null +++ b/spark/spark-3.3/.gitignore @@ -0,0 +1,12 @@ +/target/ +/.settings/ +/.classpath +/.project +/dependency-reduced-pom.xml +/doc/ +/.idea/ +*.iml +/latest/ +/spark-warehouse/ +/metastore_db/ +*.log diff --git a/spark/spark-3.3/pom.xml b/spark/spark-3.3/pom.xml new file mode 100644 index 0000000000..d40bea63f9 --- /dev/null +++ b/spark/spark-3.3/pom.xml @@ -0,0 +1,145 @@ + + + + 4.0.0 + + org.apache.sedona + sedona-spark-parent-${spark.compat.version}_${scala.compat.version} + 1.6.1-SNAPSHOT + ../pom.xml + + sedona-spark-3.3_${scala.compat.version} + + ${project.groupId}:${project.artifactId} + A cluster computing system for processing large-scale spatial data: SQL API for Spark 3.3. 
+ http://sedona.apache.org/ + jar + + + false + + + + + org.apache.sedona + sedona-common + ${project.version} + + + com.fasterxml.jackson.core + * + + + + + org.apache.sedona + sedona-spark-common-${spark.compat.version}_${scala.compat.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.compat.version} + + + org.apache.spark + spark-sql_${scala.compat.version} + + + org.apache.hadoop + hadoop-client + + + org.apache.logging.log4j + log4j-1.2-api + + + org.geotools + gt-main + + + org.geotools + gt-referencing + + + org.geotools + gt-epsg-hsql + + + org.geotools + gt-geotiff + + + org.geotools + gt-coverage + + + org.geotools + gt-arcgrid + + + org.locationtech.jts + jts-core + + + org.wololo + jts2geojson + + + com.fasterxml.jackson.core + * + + + + + org.scala-lang + scala-library + + + org.scala-lang.modules + scala-collection-compat_${scala.compat.version} + + + org.scalatest + scalatest_${scala.compat.version} + + + org.mockito + mockito-inline + + + + src/main/scala + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + + diff --git a/spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000..e5f994e203 --- /dev/null +++ b/spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,2 @@ +org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala new file mode 100644 index 0000000000..4348325570 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.util.RebaseDateTime +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.util.Utils + +import scala.util.Try + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoDataSourceUtils { + + val PARQUET_REBASE_MODE_IN_READ = firstAvailableConf( + "spark.sql.parquet.datetimeRebaseModeInRead", + "spark.sql.legacy.parquet.datetimeRebaseModeInRead") + val PARQUET_REBASE_MODE_IN_WRITE = firstAvailableConf( + "spark.sql.parquet.datetimeRebaseModeInWrite", + "spark.sql.legacy.parquet.datetimeRebaseModeInWrite") + val PARQUET_INT96_REBASE_MODE_IN_READ = firstAvailableConf( + "spark.sql.parquet.int96RebaseModeInRead", + "spark.sql.legacy.parquet.int96RebaseModeInRead", + "spark.sql.legacy.parquet.datetimeRebaseModeInRead") + val PARQUET_INT96_REBASE_MODE_IN_WRITE = firstAvailableConf( + "spark.sql.parquet.int96RebaseModeInWrite", + "spark.sql.legacy.parquet.int96RebaseModeInWrite", + "spark.sql.legacy.parquet.datetimeRebaseModeInWrite") + + private def firstAvailableConf(confs: String*): String = { + confs.find(c => Try(SQLConf.get.getConfString(c)).isSuccess).get + } + + def datetimeRebaseMode( + lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) + .map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + } + .getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + + def int96RebaseMode( + lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)) + .map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. 
+ if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + } + .getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + + def creteDateRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + days: Int => + if (days < RebaseDateTime.lastSwitchJulianDay) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteDateRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + days: Int => + if (days < RebaseDateTime.lastSwitchGregorianDay) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def creteTimestampRebaseFuncInRead( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + micros: Long => + if (micros < RebaseDateTime.lastSwitchJulianTs) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } + + def creteTimestampRebaseFuncInWrite( + rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => + micros: Long => + if (micros < RebaseDateTime.lastSwitchGregorianTs) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianMicros + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala new file mode 100644 index 0000000000..bf3c2a19a9 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoDateTimeUtils { + + /** + * Converts the timestamp to milliseconds since epoch. In Spark timestamp values have + * microseconds precision, so this conversion is lossy. + */ + def microsToMillis(micros: Long): Long = { + // When the timestamp is negative i.e before 1970, we need to adjust the milliseconds portion. + // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision. + // In millis precision the above needs to be represented as (-157700927877). + Math.floorDiv(micros, MICROS_PER_MILLIS) + } + + /** + * Converts milliseconds since the epoch to microseconds. + */ + def millisToMicros(millis: Long): Long = { + Math.multiplyExact(millis, MICROS_PER_MILLIS) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala new file mode 100644 index 0000000000..702c6f31fb --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.codec.CodecConfig +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.readParquetFootersInParallel +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +import java.net.URI +import scala.collection.JavaConverters._ +import scala.util.Failure +import scala.util.Try + +class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter]) + extends ParquetFileFormat + with GeoParquetFileFormatBase + with FileFormat + with DataSourceRegister + with Logging + with Serializable { + + def this() = this(None) + + override def equals(other: Any): Boolean = other.isInstanceOf[GeoParquetFileFormat] && + other.asInstanceOf[GeoParquetFileFormat].spatialFilter == spatialFilter + + override def hashCode(): Int = getClass.hashCode() + + def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat = + new GeoParquetFileFormat(Some(spatialFilter)) + + override def inferSchema( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + GeoParquetUtils.inferSchema(sparkSession, parameters, files) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) + + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[ParquetOutputCommitter], + classOf[OutputCommitter]) + + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo( + "Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo( + "Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + + conf.setClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, classOf[OutputCommitter]) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. 
The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // This metadata is useful for keeping UDTs like Vector/Matrix. + ParquetWriteSupport.setSchema(dataSchema, conf) + + // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet + // schema and writes actual rows to Parquet files. + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) + + conf.set( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + + try { + val fieldIdWriteEnabled = + SQLConf.get.getConfString("spark.sql.parquet.fieldId.write.enabled") + conf.set("spark.sql.parquet.fieldId.write.enabled", fieldIdWriteEnabled) + } catch { + case e: NoSuchElementException => () + } + + // Sets compression scheme + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) + + // SPARK-15719: Disables writing Parquet summary files by default. + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) + } + + if (ParquetOutputFormat.getJobSummaryLevel(conf) != JobSummaryLevel.NONE + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning( + s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. 
" + + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") + } + + conf.set(ParquetOutputFormat.WRITE_SUPPORT_CLASS, classOf[GeoParquetWriteSupport].getName) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" + } + } + } + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, requiredSchema.json) + hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, requiredSchema.json) + hadoopConf.set( + SQLConf.SESSION_LOCAL_TIMEZONE.key, + sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sparkSession.sessionState.conf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sparkSession.sessionState.conf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + // TODO: if you move this into the closure it reverts to the default values. + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). 
+ val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader: Boolean = + sqlConf.parquetVectorizedReaderEnabled && + resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + + (file: PartitionedFile) => { + assert(file.partitionValues.numFields == partitionSchema.size) + + val filePath = new Path(new URI(file.filePath)) + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + filePath, + file.start, + file.start + file.length, + file.length, + Array.empty, + null) + + val sharedConf = broadcastedHadoopConf.value.value + + val footerFileMetaData = + ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new GeoParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringStartWith, + pushDownInFilterThreshold, + isCaseSensitive) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter(_)) + .reduceOption(FilterApi.and) + } else { + None + } + + // Prune file scans using pushed down spatial filters and per-column bboxes in geoparquet metadata + val shouldScanFile = + GeoParquetMetaData.parseKeyValueMetaData(footerFileMetaData.getKeyValueMetaData).forall { + metadata => spatialFilter.forall(_.evaluate(metadata.columns)) + } + if (!shouldScanFile) { + // The entire file is pruned so that we don't need to scan this file. + Seq.empty[InternalRow].iterator + } else { + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' + // *only* if the file was created by something other than "parquet-mr", so check the actual + // writer here for this file. We have to do this per-file, as each file in the table may + // have different writers. + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. 
+ def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy().startsWith("parquet-mr") + + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + val datetimeRebaseMode = GeoDataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_READ)) + val int96RebaseMode = GeoDataSourceUtils.int96RebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_READ)) + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = + new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId) + + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. + if (pushed.isDefined) { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + } + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + logWarning( + s"GeoParquet currently does not support vectorized reader. Falling back to parquet-mr") + } + logDebug(s"Falling back to parquet-mr") + // ParquetRecordReader returns InternalRow + val readSupport = new GeoParquetReadSupport( + convertTz, + enableVectorizedReader = false, + datetimeRebaseMode, + int96RebaseMode, + options) + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[InternalRow](readSupport, parquetFilter) + } else { + new ParquetRecordReader[InternalRow](readSupport) + } + val iter = new RecordReaderIterator[InternalRow](reader) + // SPARK-23457 Register a task completion listener before `initialization`. + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + reader.initialize(split, hadoopAttemptContext) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + if (partitionSchema.length == 0) { + // There is no partition columns + iter.map(unsafeProjection) + } else { + val joinedRow = new JoinedRow() + iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues))) + } + } + } + } + + override def supportDataType(dataType: DataType): Boolean = super.supportDataType(dataType) + + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = false +} + +object GeoParquetFileFormat extends Logging { + + /** + * Figures out a merged Parquet schema with a distributed Spark job. + * + * Note that locality is not taken into consideration here because: + * + * 1. For a single Parquet part-file, in most cases the footer only resides in the last block + * of that file. Thus we only need to retrieve the location of the last block. However, + * Hadoop `FileSystem` only provides API to retrieve locations of all blocks, which can be + * potentially expensive. + * + * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty + * slow. And basically locality is not available when using S3 (you can't run computation on S3 + * nodes). 
+ */ + def mergeSchemasInParallel( + parameters: Map[String, String], + filesToTouch: Seq[FileStatus], + sparkSession: SparkSession): Option[StructType] = { + val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp + + val reader = (files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean) => { + readParquetFootersInParallel(conf, files, ignoreCorruptFiles) + .map { footer => + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val keyValueMetaData = footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData + val converter = new GeoParquetToSparkSchemaConverter( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp, + parameters = parameters) + readSchemaFromFooter(footer, keyValueMetaData, converter, parameters) + } + } + + GeoSchemaMergeUtils.mergeSchemasInParallel(sparkSession, parameters, filesToTouch, reader) + } + + private def readSchemaFromFooter( + footer: Footer, + keyValueMetaData: java.util.Map[String, String], + converter: GeoParquetToSparkSchemaConverter, + parameters: Map[String, String]): StructType = { + val fileMetaData = footer.getParquetMetadata.getFileMetaData + fileMetaData.getKeyValueMetaData.asScala.toMap + .get(ParquetReadSupport.SPARK_METADATA_KEY) + .flatMap(schema => deserializeSchemaString(schema, keyValueMetaData, parameters)) + .getOrElse(converter.convert(fileMetaData.getSchema)) + } + + private def deserializeSchemaString( + schemaString: String, + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): Option[StructType] = { + // Tries to deserialize the schema string as JSON first, then falls back to the case class + // string parser (data generated by older versions of Spark SQL uses this format). 
+ val schemaOpt = Try(DataType.fromJson(schemaString).asInstanceOf[StructType]) + .recover { case _: Throwable => + logInfo( + "Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + LegacyTypeStringParser.parseString(schemaString).asInstanceOf[StructType] + } + .recoverWith { case cause: Throwable => + logWarning( + "Failed to parse and ignored serialized Spark schema in " + + s"Parquet key-value metadata:\n\t$schemaString", + cause) + Failure(cause) + } + .toOption + + schemaOpt.map(schema => + replaceGeometryColumnWithGeometryUDT(schema, keyValueMetaData, parameters)) + } + + private def replaceGeometryColumnWithGeometryUDT( + schema: StructType, + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): StructType = { + val geoParquetMetaData: GeoParquetMetaData = + GeoParquetUtils.parseGeoParquetMetaData(keyValueMetaData, parameters) + val fields = schema.fields.map { field => + field.dataType match { + case _: BinaryType if geoParquetMetaData.columns.contains(field.name) => + field.copy(dataType = GeometryUDT) + case _ => field + } + } + StructType(fields) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala new file mode 100644 index 0000000000..d44f679058 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala @@ -0,0 +1,678 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate} +import java.util.Locale + +import scala.collection.JavaConverters.asScalaBufferConverter + +import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.SparkFilterApi._ +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ + +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.sources +import org.apache.spark.unsafe.types.UTF8String + +// Needed by Sedona to support Spark 3.0 - 3.3 +/** + * Some utility function to convert Spark data source filters to Parquet filters. + */ +class GeoParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean) { + // A map which contains parquet field name and data type, if predicate push down applies. + // + // Each key in `nameToParquetField` represents a column; `dots` are used as separators for + // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. + // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + private val nameToParquetField: Map[String, ParquetPrimitiveField] = { + // Recursively traverse the parquet schema to get primitive fields that can be pushed-down. + // `parentFieldNames` is used to keep track of the current nested level when traversing. + def getPrimitiveFields( + fields: Seq[Type], + parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = { + fields.flatMap { + case p: PrimitiveType => + Some( + ParquetPrimitiveField( + fieldNames = parentFieldNames :+ p.getName, + fieldType = ParquetSchemaType( + p.getOriginalType, + p.getPrimitiveTypeName, + p.getTypeLength, + p.getDecimalMetadata))) + // Note that when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. + case g: GroupType if g.getOriginalType == null => + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are removed. + case _ => None + } + } + + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + (field.fieldNames.toSeq.quoted, field) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. 
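+ // For example, if the Parquet schema contains both "A" and "a", neither field is eligible
+ // for pushdown in case-insensitive mode.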
+ val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) + } + } + + /** + * Holds a single primitive field information stored in the underlying parquet file. + * + * @param fieldNames + * a field name as an array of string multi-identifier in parquet file + * @param fieldType + * field type related info in parquet file + */ + private case class ParquetPrimitiveField( + fieldNames: Array[String], + fieldType: ParquetSchemaType) + + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + length: Int, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) + + private def dateToDays(date: Any): Int = date match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case ld: LocalDate => DateTimeUtils.localDateToDays(ld) + } + + private def timestampToMicros(v: Any): JLong = v match { + case i: Instant => DateTimeUtils.instantToMicros(i) + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + } + + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + + private def timestampToMillis(v: Any): JLong = { + val micros = timestampToMicros(v) + val millis = GeoDateTimeUtils.microsToMillis(micros) + millis.asInstanceOf[JLong] + } + + private val makeEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => 
FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) + + // Binary.fromString and Binary.fromByteArray don't accept null values + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.eq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeNotEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetBooleanType => + (n: Array[String], v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMicros).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => + FilterApi.notEq(longColumn(n), Option(v).map(timestampToMillis).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + intColumn(n), + 
Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + } + + private val makeLt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeLtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => 
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGt + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + private val makeGtEq + : PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) + case ParquetFloatType => + (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) + case ParquetDoubleType => + (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) + + case ParquetStringType => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => + (n: Array[String], v: Any) => + 
FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case ParquetDateType if pushDownDate => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: Array[String], v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) + } + + // Returns filters that can be pushed down when reading Parquet files. + def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = { + filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true)) + } + + private def convertibleFiltersHelper( + predicate: sources.Filter, + canPartialPushDown: Boolean): Option[sources.Filter] = { + predicate match { + case sources.And(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult)) + case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) + case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) + case _ => None + } + + case sources.Or(left, right) => + val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) { + None + } else { + Some(sources.Or(leftResultOptional.get, rightResultOptional.get)) + } + case sources.Not(pred) => + val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) + resultOptional.map(sources.Not) + + case other => + if (createFilter(other).isDefined) { + Some(other) + } else { + None + } + } + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(predicate: sources.Filter): Option[FilterPredicate] = { + createFilterHelper(predicate, canPartialPushDownConjuncts = true) + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. 
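+ // For example, a filter value that is not a java.lang.Long cannot be pushed down against an
+ // INT64 column, so such a predicate is simply not converted here.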
+ private def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToParquetField(name).fieldType match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => + value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] + case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case _ => false + }) + } + + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + private def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalMeta.getScale + case _ => false + } + + private def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) + } + + /** + * @param predicate + * the input filter predicates. Not all the predicates can be pushed down. + * @param canPartialPushDownConjuncts + * whether a subset of conjuncts of predicates can be pushed down safely. Pushing ONLY one + * side of AND down is safe to do at the top level or none of its ancestors is NOT and OR. + * @return + * the Parquet-native filter predicates that are eligible for pushdown. + */ + private def createFilterHelper( + predicate: sources.Filter, + canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { + // NOTE: + // + // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, + // which can be casted to `false` implicitly. Please refer to the `eval` method of these + // operators and the `PruneFilters` rule for details. + + // Hyukjin: + // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. + // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. + // The reason why I did this is, that the actual Parquet filter checks null-safe equality + // comparison. + // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because + // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. + // Probably I missed something and obviously this should be changed. 
+ + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + makeNotEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => + makeLt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeLtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => + makeGt + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => + makeGtEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, value)) + + case sources.And(lhs, rhs) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = + createFilterHelper(lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(rhs, canPartialPushDownConjuncts) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) + case _ => None + } + + case sources.Or(lhs, rhs) => + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. + // Let's say we have + // (a1 AND a2) OR (b1 AND b2), + // a1 and b1 is convertible, while a2 and b2 is not. 
+ // The predicate can be converted as + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logical in And predicate, we can push down (a1 OR b1). + for { + lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts) + rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilterHelper(pred, canPartialPushDownConjuncts = false) + .map(FilterApi.not) + + case sources.In(name, values) + if canMakeFilterOn(name, values.head) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct + .flatMap { v => + makeEq + .lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldNames, v)) + } + .reduceLeftOption(FilterApi.or) + + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => + Option(prefix).map { v => + FilterApi.userDefined( + binaryColumn(nameToParquetField(name).fieldNames), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + value != null && UTF8String + .fromBytes(value.getBytes) + .startsWith(UTF8String.fromBytes(strToBinary.getBytes)) + } + }) + } + + case _ => None + } + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala new file mode 100644 index 0000000000..a3c2be5d22 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.io.api.RecordMaterializer +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema._ +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ + +import java.time.ZoneId +import java.util.{Locale, Map => JMap} +import scala.collection.JavaConverters._ + +/** + * A Parquet [[ReadSupport]] implementation for reading Parquet records as Catalyst + * [[InternalRow]]s. + * + * The API interface of [[ReadSupport]] is a little bit over complicated because of historical + * reasons. In older versions of parquet-mr (say 1.6.0rc3 and prior), [[ReadSupport]] need to be + * instantiated and initialized twice on both driver side and executor side. The [[init()]] method + * is for driver side initialization, while [[prepareForRead()]] is for executor side. However, + * starting from parquet-mr 1.6.0, it's no longer the case, and [[ReadSupport]] is only + * instantiated and initialized on executor side. So, theoretically, now it's totally fine to + * combine these two methods into a single initialization method. The only reason (I could think + * of) to still have them here is for parquet-mr API backwards-compatibility. + * + * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from + * [[init()]] to [[prepareForRead()]], but use a private `var` for simplicity. + */ +class GeoParquetReadSupport( + override val convertTz: Option[ZoneId], + enableVectorizedReader: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String]) + extends ParquetReadSupport + with Logging { + private var catalystRequestedSchema: StructType = _ + + /** + * Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record + * readers. Responsible for figuring out Parquet requested schema used for column pruning. + */ + override def init(context: InitContext): ReadContext = { + val conf = context.getConfiguration + catalystRequestedSchema = { + val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + assert(schemaString != null, "Parquet requested schema not set.") + StructType.fromString(schemaString) + } + + val caseSensitive = + conf.getBoolean(SQLConf.CASE_SENSITIVE.key, SQLConf.CASE_SENSITIVE.defaultValue.get) + val schemaPruningEnabled = conf.getBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get) + val parquetFileSchema = context.getFileSchema + val parquetClippedSchema = ParquetReadSupport.clipParquetSchema( + parquetFileSchema, + catalystRequestedSchema, + caseSensitive) + + // We pass two schema to ParquetRecordMaterializer: + // - parquetRequestedSchema: the schema of the file data we want to read + // - catalystRequestedSchema: the schema of the rows we want to return + // The reader is responsible for reconciling the differences between the two. 
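+    // For illustration only (hypothetical schemas): if the file contains
+    //   message root { required int32 id; optional binary geometry; optional binary name (UTF8); }
+    // and the query selects only `id` and `geometry`, the clipped Parquet schema keeps just
+    // those two fields, while catalystRequestedSchema describes the corresponding Catalyst
+    // struct (with `geometry` read back as GeometryUDT when the GeoParquet metadata marks it
+    // as a geometry column).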
+ val parquetRequestedSchema = if (schemaPruningEnabled && !enableVectorizedReader) { + // Parquet-MR reader requires that parquetRequestedSchema include only those fields present + // in the underlying parquetFileSchema. Therefore, we intersect the parquetClippedSchema + // with the parquetFileSchema + GeoParquetReadSupport + .intersectParquetGroups(parquetClippedSchema, parquetFileSchema) + .map(groupType => new MessageType(groupType.getName, groupType.getFields)) + .getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE) + } else { + // Spark's vectorized reader only support atomic types currently. It also skip fields + // in parquetRequestedSchema which are not present in the file. + parquetClippedSchema + } + logDebug( + s"""Going to read the following fields from the Parquet file with the following schema: + |Parquet file schema: + |$parquetFileSchema + |Parquet clipped schema: + |$parquetClippedSchema + |Parquet requested schema: + |$parquetRequestedSchema + |Catalyst requested schema: + |${catalystRequestedSchema.treeString} + """.stripMargin) + new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. + */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + val parquetRequestedSchema = readContext.getRequestedSchema + new GeoParquetRecordMaterializer( + parquetRequestedSchema, + GeoParquetReadSupport.expandUDT(catalystRequestedSchema), + new GeoParquetToSparkSchemaConverter(keyValueMetaData, conf, parameters), + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters) + } +} + +object GeoParquetReadSupport extends Logging { + + /** + * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist in + * `catalystSchema`, and adding those only exist in `catalystSchema`. + */ + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + val clippedParquetFields = + clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema, caseSensitive) + if (clippedParquetFields.isEmpty) { + ParquetSchemaConverter.EMPTY_MESSAGE + } else { + Types + .buildMessage() + .addFields(clippedParquetFields: _*) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + } + } + + private def clipParquetType( + parquetType: Type, + catalystType: DataType, + caseSensitive: Boolean): Type = { + catalystType match { + case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => + // Only clips array types with nested type as element type. + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) + + case t: MapType + if !isPrimitiveCatalystType(t.keyType) || + !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested key type or value type + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) + + case t: StructType => + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) + + case _ => + // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able + // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. 
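+        // For illustration only: a GeometryUDT column falls into this branch, so its Parquet
+        // type (typically a BINARY field holding WKB) is returned unchanged here; decoding the
+        // WKB bytes into geometries is handled later by GeoParquetRowConverter.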
+ parquetType + } + } + + /** + * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to + * [[AtomicType]]. For example, [[CalendarIntervalType]] is primitive, but it's not an + * [[AtomicType]]. + */ + private def isPrimitiveCatalystType(dataType: DataType): Boolean = { + dataType match { + case _: ArrayType | _: MapType | _: StructType => false + case _ => true + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type + * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or + * a [[StructType]]. + */ + private def clipParquetListType( + parquetList: GroupType, + elementType: DataType, + caseSensitive: Boolean): Type = { + // Precondition of this method, should only be called for lists with nested element types. + assert(!isPrimitiveCatalystType(elementType)) + + // Unannotated repeated group should be interpreted as required list of required element, so + // list element type is just the group itself. Clip it. + if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { + clipParquetType(parquetList, elementType, caseSensitive) + } else { + assert( + parquetList.getOriginalType == OriginalType.LIST, + "Invalid Parquet schema. " + + "Original type of annotated Parquet lists must be LIST: " + + parquetList.toString) + + assert( + parquetList.getFieldCount == 1 && parquetList + .getType(0) + .isRepetition(Repetition.REPEATED), + "Invalid Parquet schema. " + + "LIST-annotated group should only have exactly one repeated field: " + + parquetList) + + // Precondition of this method, should only be called for lists with nested element types. + assert(!parquetList.getType(0).isPrimitive) + + val repeatedGroup = parquetList.getType(0).asGroupType() + + // If the repeated field is a group with multiple fields, or the repeated field is a group + // with one field and is named either "array" or uses the LIST-annotated group's name with + // "_tuple" appended then the repeated type is the element type and elements are required. + // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the + // only field. + if (repeatedGroup.getFieldCount > 1 || + repeatedGroup.getName == "array" || + repeatedGroup.getName == parquetList.getName + "_tuple") { + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) + .named(parquetList.getName) + } else { + // Otherwise, the repeated field's type is the element type with the repeated field's + // repetition. + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField( + Types + .repeatedGroup() + .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) + .named(repeatedGroup.getName)) + .named(parquetList.getName) + } + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. Either key type or + * value type of the [[MapType]] must be a nested type, namely an [[ArrayType]], a [[MapType]], + * or a [[StructType]]. + */ + private def clipParquetMapType( + parquetMap: GroupType, + keyType: DataType, + valueType: DataType, + caseSensitive: Boolean): GroupType = { + // Precondition of this method, only handles maps with nested key types or value types. 
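+    // For illustration only (hypothetical type): for a Catalyst MAP<STRING, STRUCT<a: INT>>,
+    // the key side is primitive and passes through clipParquetType unchanged, while the
+    // STRUCT-valued side is clipped recursively before the MAP group is rebuilt below.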
+ assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) + + val repeatedGroup = parquetMap.getType(0).asGroupType() + val parquetKeyType = repeatedGroup.getType(0) + val parquetValueType = repeatedGroup.getType(1) + + val clippedRepeatedGroup = + Types + .repeatedGroup() + .as(repeatedGroup.getOriginalType) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) + .named(repeatedGroup.getName) + + Types + .buildGroup(parquetMap.getRepetition) + .as(parquetMap.getOriginalType) + .addField(clippedRepeatedGroup) + .named(parquetMap.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return + * A clipped [[GroupType]], which has at least one field. + * @note + * Parquet doesn't allow creating empty [[GroupType]] instances except for empty + * [[MessageType]]. Because it's legal to construct an empty requested schema for column + * pruning. + */ + private def clipParquetGroup( + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) + Types + .buildGroup(parquetRecord.getRepetition) + .as(parquetRecord.getOriginalType) + .addFields(clippedParquetFields: _*) + .named(parquetRecord.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return + * A list of clipped [[GroupType]] fields, which can be empty. + */ + private def clipParquetGroupFields( + parquetRecord: GroupType, + structType: StructType, + caseSensitive: Boolean): Seq[Type] = { + val toParquet = new SparkToGeoParquetSchemaConverter(writeLegacyParquetFormat = false) + if (caseSensitive) { + val caseSensitiveParquetFieldMap = + parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + structType.map { f => + caseSensitiveParquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType, caseSensitive)) + .getOrElse(toParquet.convertField(f)) + } + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + structType.map { f => + caseInsensitiveParquetFieldMap + .get(f.name.toLowerCase(Locale.ROOT)) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException( + s"""Found duplicate field(s) "${f.name}": """ + + s"$parquetTypesString in case-insensitive mode") + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + } + } + .getOrElse(toParquet.convertField(f)) + } + } + } + + /** + * Computes the structural intersection between two Parquet group types. This is used to create + * a requestedSchema for ReadContext of Parquet-MR reader. Parquet-MR reader does not support + * the nested field access to non-existent field while parquet library does support to read the + * non-existent field by regular field access. 
+ */ + private def intersectParquetGroups( + groupType1: GroupType, + groupType2: GroupType): Option[GroupType] = { + val fields = + groupType1.getFields.asScala + .filter(field => groupType2.containsField(field.getName)) + .flatMap { + case field1: GroupType => + val field2 = groupType2.getType(field1.getName) + if (field2.isPrimitive) { + None + } else { + intersectParquetGroups(field1, field2.asGroupType) + } + case field1 => Some(field1) + } + + if (fields.nonEmpty) { + Some(groupType1.withNewFields(fields.asJava)) + } else { + None + } + } + + def expandUDT(schema: StructType): StructType = { + def expand(dataType: DataType): DataType = { + dataType match { + case t: ArrayType => + t.copy(elementType = expand(t.elementType)) + + case t: MapType => + t.copy(keyType = expand(t.keyType), valueType = expand(t.valueType)) + + case t: StructType => + val expandedFields = t.fields.map(f => f.copy(dataType = expand(f.dataType))) + t.copy(fields = expandedFields) + + // Don't expand GeometryUDT types. We'll treat geometry columns specially in + // GeoParquetRowConverter + case t: GeometryUDT => t + + case t: UserDefinedType[_] => + t.sqlType + + case t => + t + } + } + + expand(schema).asInstanceOf[StructType] + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala new file mode 100644 index 0000000000..dedbb237b5 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.time.ZoneId +import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} +import org.apache.parquet.schema.MessageType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.types.StructType + +/** + * A [[RecordMaterializer]] for Catalyst rows. 
+ * + * @param parquetSchema + * Parquet schema of the records to be read + * @param catalystSchema + * Catalyst schema of the rows to be constructed + * @param schemaConverter + * A Parquet-Catalyst schema converter that helps initializing row converters + * @param convertTz + * the optional time zone to convert to int96 data + * @param datetimeRebaseSpec + * the specification of rebasing date/timestamp from Julian to Proleptic Gregorian calendar: + * mode + optional original time zone + * @param int96RebaseSpec + * the specification of rebasing INT96 timestamp from Julian to Proleptic Gregorian calendar + * @param parameters + * Options for reading GeoParquet files. For example, if legacyMode is enabled or not. + */ +class GeoParquetRecordMaterializer( + parquetSchema: MessageType, + catalystSchema: StructType, + schemaConverter: GeoParquetToSparkSchemaConverter, + convertTz: Option[ZoneId], + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String]) + extends RecordMaterializer[InternalRow] { + private val rootConverter = new GeoParquetRowConverter( + schemaConverter, + parquetSchema, + catalystSchema, + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters, + NoopUpdater) + + override def getCurrentRecord: InternalRow = rootConverter.currentRecord + + override def getRootConverter: GroupConverter = rootConverter +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala new file mode 100644 index 0000000000..2f2eea38cd --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala @@ -0,0 +1,745 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.OriginalType.LIST +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.{GroupType, OriginalType, Type} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CaseInsensitiveMap, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.locationtech.jts.io.WKBReader + +import java.math.{BigDecimal, BigInteger} +import java.time.{ZoneId, ZoneOffset} +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +/** + * A [[ParquetRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s. + * Since Catalyst `StructType` is also a Parquet record, this converter can be used as root + * converter. Take the following Parquet type as an example: + * {{{ + * message root { + * required int32 f1; + * optional group f2 { + * required double f21; + * optional binary f22 (utf8); + * } + * } + * }}} + * 5 converters will be created: + * + * - a root [[ParquetRowConverter]] for [[org.apache.parquet.schema.MessageType]] `root`, which + * contains: + * - a [[ParquetPrimitiveConverter]] for required + * [[org.apache.parquet.schema.OriginalType.INT_32]] field `f1`, and + * - a nested [[ParquetRowConverter]] for optional [[GroupType]] `f2`, which contains: + * - a [[ParquetPrimitiveConverter]] for required + * [[org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE]] field `f21`, and + * - a [[ParquetStringConverter]] for optional + * [[org.apache.parquet.schema.OriginalType.UTF8]] string field `f22` + * + * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have + * any "parent" container. + * + * @param schemaConverter + * A utility converter used to convert Parquet types to Catalyst types. + * @param parquetType + * Parquet schema of Parquet records + * @param catalystType + * Spark SQL schema that corresponds to the Parquet record type. User-defined types other than + * [[GeometryUDT]] should have been expanded. + * @param convertTz + * the optional time zone to convert to int96 data + * @param datetimeRebaseMode + * the mode of rebasing date/timestamp from Julian to Proleptic Gregorian calendar + * @param int96RebaseMode + * the mode of rebasing INT96 timestamp from Julian to Proleptic Gregorian calendar + * @param parameters + * Options for reading GeoParquet files. For example, if legacyMode is enabled or not. 
+ * @param updater + * An updater which propagates converted field values to the parent container + */ +private[parquet] class GeoParquetRowConverter( + schemaConverter: GeoParquetToSparkSchemaConverter, + parquetType: GroupType, + catalystType: StructType, + convertTz: Option[ZoneId], + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + int96RebaseMode: LegacyBehaviorPolicy.Value, + parameters: Map[String, String], + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) + with Logging { + + assert( + parquetType.getFieldCount <= catalystType.length, + s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema: + | + |Parquet schema: + |$parquetType + |Catalyst schema: + |${catalystType.prettyJson} + """.stripMargin) + + assert( + !catalystType.existsRecursively(t => + !t.isInstanceOf[GeometryUDT] && t.isInstanceOf[UserDefinedType[_]]), + s"""User-defined types in Catalyst schema should have already been expanded: + |${catalystType.prettyJson} + """.stripMargin) + + logDebug(s"""Building row converter for the following schema: + | + |Parquet form: + |$parquetType + |Catalyst form: + |${catalystType.prettyJson} + """.stripMargin) + + /** + * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates + * converted filed values to the `ordinal`-th cell in `currentRow`. + */ + private final class RowUpdater(row: InternalRow, ordinal: Int) extends ParentContainerUpdater { + override def set(value: Any): Unit = row(ordinal) = value + override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(value: Short): Unit = row.setShort(ordinal, value) + override def setInt(value: Int): Unit = row.setInt(ordinal, value) + override def setLong(value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) + } + + private[this] val currentRow = new SpecificInternalRow(catalystType.map(_.dataType)) + + /** + * The [[InternalRow]] converted from an entire Parquet record. + */ + def currentRecord: InternalRow = currentRow + + private val dateRebaseFunc = + GeoDataSourceUtils.creteDateRebaseFuncInRead(datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInRead(datetimeRebaseMode, "Parquet") + + private val int96RebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInRead(int96RebaseMode, "Parquet INT96") + + // Converters for each field. + private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { + // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false + // to prevent throwing IllegalArgumentException when searching catalyst type's field index + val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) { + catalystType.fieldNames.zipWithIndex.toMap + } else { + CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap) + } + parquetType.getFields.asScala.map { parquetField => + val fieldIndex = catalystFieldNameToIndex(parquetField.getName) + val catalystField = catalystType(fieldIndex) + // Converted field value should be set to the `fieldIndex`-th cell of `currentRow` + newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex)) + }.toArray + } + + // Updaters for each field. 
+ private[this] val fieldUpdaters: Array[ParentContainerUpdater] = fieldConverters.map(_.updater) + + override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) + + override def end(): Unit = { + var i = 0 + while (i < fieldUpdaters.length) { + fieldUpdaters(i).end() + i += 1 + } + updater.set(currentRow) + } + + override def start(): Unit = { + var i = 0 + val numFields = currentRow.numFields + while (i < numFields) { + currentRow.setNullAt(i) + i += 1 + } + i = 0 + while (i < fieldUpdaters.length) { + fieldUpdaters(i).start() + i += 1 + } + } + + /** + * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type + * `catalystType`. Converted values are handled by `updater`. + */ + private def newConverter( + parquetType: Type, + catalystType: DataType, + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { + + catalystType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => + new ParquetPrimitiveConverter(updater) + + case GeometryUDT => + if (parquetType.isPrimitive) { + new ParquetPrimitiveConverter(updater) { + override def addBinary(value: Binary): Unit = { + val wkbReader = new WKBReader() + val geom = wkbReader.read(value.getBytes) + updater.set(GeometryUDT.serialize(geom)) + } + } + } else { + if (GeoParquetUtils.isLegacyMode(parameters)) { + new ParquetArrayConverter( + parquetType.asGroupType(), + ArrayType(ByteType, containsNull = false), + updater) { + override def end(): Unit = { + val wkbReader = new WKBReader() + val byteArray = currentArray.map(_.asInstanceOf[Byte]).toArray + val geom = wkbReader.read(byteArray) + updater.set(GeometryUDT.serialize(geom)) + } + } + } else { + throw new IllegalArgumentException( + s"Parquet type for geometry column is $parquetType. This parquet file could be written by " + + "Apache Sedona <= 1.3.1-incubating. Please use option(\"legacyMode\", \"true\") to read this file.") + } + } + + case ByteType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + updater.setByte(value.asInstanceOf[ByteType#InternalType]) + + override def addBinary(value: Binary): Unit = { + val bytes = value.getBytes + for (b <- bytes) { + updater.set(b) + } + } + } + + case ShortType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + updater.setShort(value.asInstanceOf[ShortType#InternalType]) + } + + // For INT32 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => + new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For INT64 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 => + new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals + case t: DecimalType + if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY || + parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY => + new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + case t: DecimalType => + throw new RuntimeException( + s"Unable to create Parquet converter for decimal type ${t.json} whose Parquet type is " + + s"$parquetType. 
Parquet DECIMAL type can only be backed by INT32, INT64, " + + "FIXED_LEN_BYTE_ARRAY, or BINARY.") + + case StringType => + new ParquetStringConverter(updater) + + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(timestampRebaseFunc(value)) + } + } + + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + val micros = GeoDateTimeUtils.millisToMicros(value) + updater.setLong(timestampRebaseFunc(micros)) + } + } + + // INT96 timestamp doesn't have a logical type, here we check the physical type instead. + case TimestampType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT96 => + new ParquetPrimitiveConverter(updater) { + // Converts nanosecond timestamps stored as INT96 + override def addBinary(value: Binary): Unit = { + val julianMicros = ParquetRowConverter.binaryToSQLTimestamp(value) + val gregorianMicros = int96RebaseFunc(julianMicros) + val adjTime = convertTz + .map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) + .getOrElse(gregorianMicros) + updater.setLong(adjTime) + } + } + + case DateType => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = { + updater.set(dateRebaseFunc(value)) + } + } + + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + + case t: ArrayType => + new ParquetArrayConverter(parquetType.asGroupType(), t, updater) + + case t: MapType => + new ParquetMapConverter(parquetType.asGroupType(), t, updater) + + case t: StructType => + val wrappedUpdater = { + // SPARK-30338: avoid unnecessary InternalRow copying for nested structs: + // There are two cases to handle here: + // + // 1. Parent container is a map or array: we must make a deep copy of the mutable row + // because this converter may be invoked multiple times per Parquet input record + // (if the map or array contains multiple elements). + // + // 2. Parent container is a struct: we don't need to copy the row here because either: + // + // (a) all ancestors are structs and therefore no copying is required because this + // converter will only be invoked once per Parquet input record, or + // (b) some ancestor is struct that is nested in a map or array and that ancestor's + // converter will perform deep-copying (which will recursively copy this row). + if (updater.isInstanceOf[RowUpdater]) { + // `updater` is a RowUpdater, implying that the parent container is a struct. + updater + } else { + // `updater` is NOT a RowUpdater, implying that the parent container a map or array. 
+ new ParentContainerUpdater { + override def set(value: Any): Unit = { + updater.set(value.asInstanceOf[SpecificInternalRow].copy()) // deep copy + } + } + } + } + new GeoParquetRowConverter( + schemaConverter, + parquetType.asGroupType(), + t, + convertTz, + datetimeRebaseMode, + int96RebaseMode, + parameters, + wrappedUpdater) + + case t => + throw new RuntimeException( + s"Unable to create Parquet converter for data type ${t.json} " + + s"whose Parquet type is $parquetType") + } + } + + /** + * Parquet converter for strings. A dictionary is used to minimize string decoding cost. + */ + private final class ParquetStringConverter(updater: ParentContainerUpdater) + extends ParquetPrimitiveConverter(updater) { + + private var expandedDictionary: Array[UTF8String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { i => + UTF8String.fromBytes(dictionary.decodeToBinary(i).getBytes) + } + } + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + override def addBinary(value: Binary): Unit = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here we + // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying + // it. + val buffer = value.toByteBuffer + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() + updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) + } + } + + /** + * Parquet converter for fixed-precision decimals. + */ + private abstract class ParquetDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetPrimitiveConverter(updater) { + + protected var expandedDictionary: Array[Decimal] = _ + + override def hasDictionarySupport: Boolean = true + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + // Converts decimals stored as INT32 + override def addInt(value: Int): Unit = { + addLong(value: Long) + } + + // Converts decimals stored as INT64 + override def addLong(value: Long): Unit = { + updater.set(decimalFromLong(value)) + } + + // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY + override def addBinary(value: Binary): Unit = { + updater.set(decimalFromBinary(value)) + } + + protected def decimalFromLong(value: Long): Decimal = { + Decimal(value, precision, scale) + } + + protected def decimalFromBinary(value: Binary): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + val unscaled = ParquetRowConverter.binaryToUnscaledLong(value) + Decimal(unscaled, precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
+ Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) + } + } + } + + private class ParquetIntDictionaryAwareDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToInt(id).toLong) + } + } + } + + private class ParquetLongDictionaryAwareDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToLong(id)) + } + } + } + + private class ParquetBinaryDictionaryAwareDecimalConverter( + precision: Int, + scale: Int, + updater: ParentContainerUpdater) + extends ParquetDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromBinary(dictionary.decodeToBinary(id)) + } + } + } + + /** + * Parquet converter for arrays. Spark SQL arrays are represented as Parquet lists. Standard + * Parquet lists are represented as a 3-level group annotated by `LIST`: + * {{{ + * group (LIST) { <-- parquetSchema points here + * repeated group list { + * element; + * } + * } + * }}} + * The `parquetSchema` constructor argument points to the outermost group. + * + * However, before this representation is standardized, some Parquet libraries/tools also use + * some non-standard formats to represent list-like structures. Backwards-compatibility rules + * for handling these cases are described in Parquet format spec. + * + * @see + * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + */ + private class ParquetArrayConverter( + parquetSchema: GroupType, + catalystSchema: ArrayType, + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) { + + protected[this] val currentArray: mutable.ArrayBuffer[Any] = ArrayBuffer.empty[Any] + + private[this] val elementConverter: Converter = { + val repeatedType = parquetSchema.getType(0) + val elementType = catalystSchema.elementType + + // At this stage, we're not sure whether the repeated field maps to the element type or is + // just the syntactic repeated group of the 3-level standard LIST layout. Take the following + // Parquet LIST-annotated group type as an example: + // + // optional group f (LIST) { + // repeated group list { + // optional group element { + // optional int32 element; + // } + // } + // } + // + // This type is ambiguous: + // + // 1. When interpreted as a standard 3-level layout, the `list` field is just the syntactic + // group, and the entire type should be translated to: + // + // ARRAY> + // + // 2. On the other hand, when interpreted as a non-standard 2-level layout, the `list` field + // represents the element type, and the entire type should be translated to: + // + // ARRAY>> + // + // Here we try to convert field `list` into a Catalyst type to see whether the converted type + // matches the Catalyst array element type. If it doesn't match, then it's case 1; otherwise, + // it's case 2. 
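+      // For illustration only, continuing the example above: if the Catalyst element type is
+      // STRUCT<element: STRUCT<element: INT>>, the guessed type matches and the repeated `list`
+      // field itself is treated as the element (case 2); if the Catalyst element type is
+      // STRUCT<element: INT>, the types do not match and the converter is built from
+      // `list.element` instead (case 1).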
+ val guessedElementType = schemaConverter.convertFieldWithGeo(repeatedType) + + if (DataType.equalsIgnoreCompatibleNullability(guessedElementType, elementType)) { + // If the repeated field corresponds to the element type, creates a new converter using the + // type of the repeated field. + newConverter( + repeatedType, + elementType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentArray += value + }) + } else { + // If the repeated field corresponds to the syntactic group in the standard 3-level Parquet + // LIST layout, creates a new converter using the only child field of the repeated field. + assert(!repeatedType.isPrimitive && repeatedType.asGroupType().getFieldCount == 1) + new ElementConverter(repeatedType.asGroupType().getType(0), elementType) + } + } + + override def getConverter(fieldIndex: Int): Converter = elementConverter + + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + + override def start(): Unit = currentArray.clear() + + /** Array element converter */ + private final class ElementConverter(parquetType: Type, catalystType: DataType) + extends GroupConverter { + + private var currentElement: Any = _ + + private[this] val converter = + newConverter( + parquetType, + catalystType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentElement = value + }) + + override def getConverter(fieldIndex: Int): Converter = converter + + override def end(): Unit = currentArray += currentElement + + override def start(): Unit = currentElement = null + } + } + + /** Parquet converter for maps */ + private final class ParquetMapConverter( + parquetType: GroupType, + catalystType: MapType, + updater: ParentContainerUpdater) + extends ParquetGroupConverter(updater) { + + private[this] val currentKeys = ArrayBuffer.empty[Any] + private[this] val currentValues = ArrayBuffer.empty[Any] + + private[this] val keyValueConverter = { + val repeatedType = parquetType.getType(0).asGroupType() + new KeyValueConverter( + repeatedType.getType(0), + repeatedType.getType(1), + catalystType.keyType, + catalystType.valueType) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override def end(): Unit = { + // The parquet map may contains null or duplicated map keys. When it happens, the behavior is + // undefined. + // TODO (SPARK-26174): disallow it with a config. + updater.set( + new ArrayBasedMapData( + new GenericArrayData(currentKeys.toArray), + new GenericArrayData(currentValues.toArray))) + } + + override def start(): Unit = { + currentKeys.clear() + currentValues.clear() + } + + /** Parquet converter for key-value pairs within the map. 
*/ + private final class KeyValueConverter( + parquetKeyType: Type, + parquetValueType: Type, + catalystKeyType: DataType, + catalystValueType: DataType) + extends GroupConverter { + + private var currentKey: Any = _ + + private var currentValue: Any = _ + + private[this] val converters = Array( + // Converter for keys + newConverter( + parquetKeyType, + catalystKeyType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentKey = value + }), + + // Converter for values + newConverter( + parquetValueType, + catalystValueType, + new ParentContainerUpdater { + override def set(value: Any): Unit = currentValue = value + })) + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + override def end(): Unit = { + currentKeys += currentKey + currentValues += currentValue + } + + override def start(): Unit = { + currentKey = null + currentValue = null + } + } + } + + private trait RepeatedConverter { + private[this] val currentArray = ArrayBuffer.empty[Any] + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray.clear() + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. + */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter + with RepeatedConverter + with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private[this] val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = + elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. 
+ */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter + with HasParentContainerUpdater + with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private[this] val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala new file mode 100644 index 0000000000..eab20875a6 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema._ +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.checkConversionRequirement +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ + +/** + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]]. + * + * Parquet format backwards-compatibility rules are respected when converting Parquet + * [[MessageType]] schemas. + * + * @see + * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * @param assumeBinaryIsString + * Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields. + * @param assumeInt96IsTimestamp + * Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields. + * @param parameters + * Options for reading GeoParquet files. 
+ */ +class GeoParquetToSparkSchemaConverter( + keyValueMetaData: java.util.Map[String, String], + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + parameters: Map[String, String]) { + + private val geoParquetMetaData: GeoParquetMetaData = + GeoParquetUtils.parseGeoParquetMetaData(keyValueMetaData, parameters) + + def this( + keyValueMetaData: java.util.Map[String, String], + conf: SQLConf, + parameters: Map[String, String]) = this( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = conf.isParquetBinaryAsString, + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + parameters = parameters) + + def this( + keyValueMetaData: java.util.Map[String, String], + conf: Configuration, + parameters: Map[String, String]) = this( + keyValueMetaData = keyValueMetaData, + assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + parameters = parameters) + + /** + * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. + */ + def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + + private def convert(parquetSchema: GroupType): StructType = { + val fields = parquetSchema.getFields.asScala.map { field => + field.getRepetition match { + case OPTIONAL => + StructField(field.getName, convertFieldWithGeo(field), nullable = true) + + case REQUIRED => + StructField(field.getName, convertFieldWithGeo(field), nullable = false) + + case REPEATED => + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + val arrayType = ArrayType(convertFieldWithGeo(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) + } + } + + StructType(fields.toSeq) + } + + /** + * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. + */ + def convertFieldWithGeo(parquetType: Type): DataType = parquetType match { + case t: PrimitiveType => convertPrimitiveField(t) + case t: GroupType => convertGroupField(t.asGroupType()) + } + + private def isGeometryField(fieldName: String): Boolean = + geoParquetMetaData.columns.contains(fieldName) + + private def convertPrimitiveField(field: PrimitiveType): DataType = { + val typeName = field.getPrimitiveTypeName + val originalType = field.getOriginalType + + def typeString = + if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + + def typeNotSupported() = + throw new IllegalArgumentException(s"Parquet type not supported: $typeString") + + def typeNotImplemented() = + throw new IllegalArgumentException(s"Parquet type not yet supported: $typeString") + + def illegalType() = + throw new IllegalArgumentException(s"Illegal Parquet type: $typeString") + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. 
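+    // For illustration only (hypothetical fields): an INT32 field annotated as DECIMAL(7, 2)
+    // becomes DecimalType(7, 2) via makeDecimalType(Decimal.MAX_INT_DIGITS), whereas a
+    // BINARY-backed decimal calls makeDecimalType() with the default maxPrecision = -1, so the
+    // range check is skipped and the precision recorded in the file metadata is used as-is.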
+ def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val precision = field.getDecimalMetadata.getPrecision + val scale = field.getDecimalMetadata.getScale + + ParquetSchemaConverter.checkConversionRequirement( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + typeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + originalType match { + case INT_8 => ByteType + case INT_16 => ShortType + case INT_32 | null => IntegerType + case DATE => DateType + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) + case UINT_8 => typeNotSupported() + case UINT_16 => typeNotSupported() + case UINT_32 => typeNotSupported() + case TIME_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + originalType match { + case INT_64 | null => LongType + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) + case UINT_64 => typeNotSupported() + case TIMESTAMP_MICROS => TimestampType + case TIMESTAMP_MILLIS => TimestampType + case _ => illegalType() + } + + case INT96 => + ParquetSchemaConverter.checkConversionRequirement( + assumeInt96IsTimestamp, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + originalType match { + case UTF8 | ENUM | JSON => StringType + case null if isGeometryField(field.getName) => GeometryUDT + case null if assumeBinaryIsString => StringType + case null => BinaryType + case BSON => BinaryType + case DECIMAL => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + originalType match { + case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) + case INTERVAL => typeNotImplemented() + case _ => illegalType() + } + + case _ => illegalType() + } + } + + private def convertGroupField(field: GroupType): DataType = { + Option(field.getOriginalType).fold(convert(field): DataType) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. 
+ // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case LIST => + ParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1, + s"Invalid list type $field") + + val repeatedType = field.getType(0) + ParquetSchemaConverter.checkConversionRequirement( + repeatedType.isRepetition(REPEATED), + s"Invalid list type $field") + + if (isElementTypeWithGeo(repeatedType, field.getName)) { + ArrayType(convertFieldWithGeo(repeatedType), containsNull = false) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val optional = elementType.isRepetition(OPTIONAL) + ArrayType(convertFieldWithGeo(elementType), containsNull = optional) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case MAP | MAP_KEY_VALUE => + ParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + + val keyValueType = field.getType(0).asGroupType() + ParquetSchemaConverter.checkConversionRequirement( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val keyType = keyValueType.getType(0) + val valueType = keyValueType.getType(1) + val valueOptional = valueType.isRepetition(OPTIONAL) + MapType( + convertFieldWithGeo(keyType), + convertFieldWithGeo(valueType), + valueContainsNull = valueOptional) + + case _ => + throw new IllegalArgumentException(s"Unrecognized Parquet type: $field") + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + def isElementTypeWithGeo(repeatedType: Type, parentName: String): Boolean = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // ARRAY (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // ARRAY> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } +} + +/** + * This converter class is used to convert Spark SQL [[StructType]] to Parquet [[MessageType]]. + * + * @param writeLegacyParquetFormat + * Whether to use legacy Parquet format compatible with Spark 1.4 and prior versions when + * converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. When set to false, use + * standard format defined in parquet-format spec. 
This argument only affects Parquet write + * path. + * @param outputTimestampType + * which parquet timestamp type to use when writing. + */ +class SparkToGeoParquetSchemaConverter( + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96) + extends SparkToParquetSchemaConverter(writeLegacyParquetFormat, outputTimestampType) { + + def this(conf: SQLConf) = this( + writeLegacyParquetFormat = conf.writeLegacyParquetFormat, + outputTimestampType = conf.parquetOutputTimestampType) + + def this(conf: Configuration) = this( + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, + outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) + + /** + * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. + */ + override def convert(catalystSchema: StructType): MessageType = { + Types + .buildMessage() + .addFields(catalystSchema.map(convertField): _*) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + } + + /** + * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. + */ + override def convertField(field: StructField): Type = { + convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + } + + private def convertField(field: StructField, repetition: Type.Repetition): Type = { + GeoParquetSchemaConverter.checkFieldName(field.name) + + field.dataType match { + // =================== + // Simple atomic types + // =================== + + case BooleanType => + Types.primitive(BOOLEAN, repetition).named(field.name) + + case ByteType => + Types.primitive(INT32, repetition).as(INT_8).named(field.name) + + case ShortType => + Types.primitive(INT32, repetition).as(INT_16).named(field.name) + + case IntegerType => + Types.primitive(INT32, repetition).named(field.name) + + case LongType => + Types.primitive(INT64, repetition).named(field.name) + + case FloatType => + Types.primitive(FLOAT, repetition).named(field.name) + + case DoubleType => + Types.primitive(DOUBLE, repetition).named(field.name) + + case StringType => + Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + + case DateType => + Types.primitive(INT32, repetition).as(DATE).named(field.name) + + // NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or + // TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the + // behavior same as before. + // + // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond + // timestamp in Impala for some historical reasons. It's not recommended to be used for any + // other types and will probably be deprecated in some future version of parquet-format spec. + // That's the reason why parquet-format spec only defines `TIMESTAMP_MILLIS` and + // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`. + // + // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting + // from Spark 1.5.0, we resort to a timestamp type with microsecond precision so that we can + // store a timestamp into a `Long`. This design decision is subject to change though, for + // example, we may resort to nanosecond precision in the future. 
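The TimestampType case that follows implements exactly this mapping. A hedged usage sketch (assuming an existing SparkSession `spark` and a DataFrame `df` containing a geometry column; the output path is illustrative): selecting TIMESTAMP_MICROS before writing avoids the legacy INT96 encoding.

    spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
    df.write.format("geoparquet").save("/tmp/geoparquet_micros")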
+ case TimestampType => + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + Types.primitive(INT96, repetition).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MICROS).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + } + + case BinaryType => + Types.primitive(BINARY, repetition).named(field.name) + + // ====================== + // Decimals (legacy mode) + // ====================== + + // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and + // always store decimals in fixed-length byte arrays. To keep compatibility with these older + // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated + // by `DECIMAL`. + case DecimalType.Fixed(precision, scale) if writeLegacyParquetFormat => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(Decimal.minBytesForPrecision(precision)) + .named(field.name) + + // ======================== + // Decimals (standard mode) + // ======================== + + // Uses INT32 for 1 <= precision <= 9 + case DecimalType.Fixed(precision, scale) + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => + Types + .primitive(INT32, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses INT64 for 1 <= precision <= 18 + case DecimalType.Fixed(precision, scale) + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => + Types + .primitive(INT64, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses FIXED_LEN_BYTE_ARRAY for all other precisions + case DecimalType.Fixed(precision, scale) if !writeLegacyParquetFormat => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(Decimal.minBytesForPrecision(precision)) + .named(field.name) + + // =================================== + // ArrayType and MapType (legacy mode) + // =================================== + + // Spark 1.4.x and prior versions convert `ArrayType` with nullable elements into a 3-level + // `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro + // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element + // field name "array" is borrowed from parquet-avro. + case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat => + // group (LIST) { + // optional group bag { + // repeated array; + // } + // } + + // This should not use `listOfElements` here because this new method checks if the + // element name is `element` in the `GroupType` and throws an exception if not. + // As mentioned above, Spark prior to 1.4.x writes `ArrayType` as `LIST` but with + // `array` as its element name as below. Therefore, we build manually + // the correct group type here via the builder. (See SPARK-16777) + Types + .buildGroup(repetition) + .as(LIST) + .addField( + Types + .buildGroup(REPEATED) + // "array" is the name chosen by parquet-hive (1.7.0 and prior version) + .addField(convertField(StructField("array", elementType, nullable))) + .named("bag")) + .named(field.name) + + // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level + // LIST structure. 
This behavior mimics parquet-avro (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat => + // group (LIST) { + // repeated element; + // } + + // Here too, we should not use `listOfElements`. (See SPARK-16777) + Types + .buildGroup(repetition) + .as(LIST) + // "array" is the name chosen by parquet-avro (1.7.0 and prior version) + .addField(convertField(StructField("array", elementType, nullable), REPEATED)) + .named(field.name) + + // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by + // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. + case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat => + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // required key; + // value; + // } + // } + ConversionPatterns.mapType( + repetition, + field.name, + convertField(StructField("key", keyType, nullable = false)), + convertField(StructField("value", valueType, valueContainsNull))) + + // ===================================== + // ArrayType and MapType (standard mode) + // ===================================== + + case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat => + // group (LIST) { + // repeated group list { + // element; + // } + // } + Types + .buildGroup(repetition) + .as(LIST) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("element", elementType, containsNull))) + .named("list")) + .named(field.name) + + case MapType(keyType, valueType, valueContainsNull) => + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + Types + .buildGroup(repetition) + .as(MAP) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("key", keyType, nullable = false))) + .addField(convertField(StructField("value", valueType, valueContainsNull))) + .named("key_value")) + .named(field.name) + + // =========== + // Other types + // =========== + + case StructType(fields) => + fields + .foldLeft(Types.buildGroup(repetition)) { (builder, field) => + builder.addField(convertField(field)) + } + .named(field.name) + + case udt: UserDefinedType[_] => + convertField(field.copy(dataType = udt.sqlType)) + + case _ => + throw new IllegalArgumentException( + s"Unsupported data type ${field.dataType.catalogString}") + } + } +} + +private[sql] object GeoParquetSchemaConverter { + def checkFieldName(name: String): Unit = { + // ,;{}()\n\t= and space are special characters in Parquet schema + checkConversionRequirement( + !name.matches(".*[ ,;{}()\n\t=].*"), + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ").trim) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala new file mode 100644 index 0000000000..477d744441 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + +import scala.language.existentials + +object GeoParquetUtils { + def inferSchema( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val parquetOptions = new ParquetOptions(parameters, sparkSession.sessionState.conf) + val shouldMergeSchemas = parquetOptions.mergeSchema + val mergeRespectSummaries = sparkSession.sessionState.conf.isParquetSchemaRespectSummaries + val filesByType = splitFiles(files) + val filesToTouch = + if (shouldMergeSchemas) { + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq.empty + } else { + filesByType.data + } + needMerged ++ filesByType.metadata ++ filesByType.commonMetadata + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + filesByType.commonMetadata.headOption + // Falls back to "_metadata" + .orElse(filesByType.metadata.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(filesByType.data.headOption) + .toSeq + } + GeoParquetFileFormat.mergeSchemasInParallel(parameters, filesToTouch, sparkSession) + } + + case class FileTypes( + data: Seq[FileStatus], + metadata: Seq[FileStatus], + commonMetadata: Seq[FileStatus]) + + private def splitFiles(allFiles: Seq[FileStatus]): FileTypes = { + val leaves = allFiles.toArray.sortBy(_.getPath.toString) + + FileTypes( + data = leaves.filterNot(f => isSummaryFile(f.getPath)), + metadata = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE), + commonMetadata = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)) + } + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + /** + * Legacy mode option is for reading Parquet files written by old versions of Apache Sedona (<= + * 1.3.1-incubating). Such files are actually not GeoParquet files and do not have GeoParquet + * file metadata. Geometry fields were encoded as list of bytes and stored as group type in + * Parquet files. The Definition of GeometryUDT before 1.4.0 was: + * {{{ + * case class GeometryUDT extends UserDefinedType[Geometry] { + * override def sqlType: DataType = ArrayType(ByteType, containsNull = false) + * // ... + * }}} + * Since 1.4.0, the sqlType of GeometryUDT is changed to BinaryType. 
This is a breaking change + * for reading old Parquet files. To read old Parquet files, users need to use "geoparquet" + * format and set legacyMode to true. + * @param parameters + * user provided parameters for reading GeoParquet files using `.option()` method, e.g. + * `spark.read.format("geoparquet").option("legacyMode", "true").load("path")` + * @return + * true if legacyMode is set to true, false otherwise + */ + def isLegacyMode(parameters: Map[String, String]): Boolean = + parameters.getOrElse("legacyMode", "false").toBoolean + + /** + * Parse GeoParquet file metadata from Parquet file metadata. Legacy parquet files do not + * contain GeoParquet file metadata, so we'll simply return an empty GeoParquetMetaData object + * when legacy mode is enabled. + * @param keyValueMetaData + * Parquet file metadata + * @param parameters + * user provided parameters for reading GeoParquet files + * @return + * GeoParquetMetaData object + */ + def parseGeoParquetMetaData( + keyValueMetaData: java.util.Map[String, String], + parameters: Map[String, String]): GeoParquetMetaData = { + val isLegacyMode = GeoParquetUtils.isLegacyMode(parameters) + GeoParquetMetaData.parseKeyValueMetaData(keyValueMetaData).getOrElse { + if (isLegacyMode) { + GeoParquetMetaData(None, "", Map.empty) + } else { + throw new IllegalArgumentException("GeoParquet file does not contain valid geo metadata") + } + } + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala new file mode 100644 index 0000000000..90d6d962f4 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala @@ -0,0 +1,628 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.FinalizedWriteContext +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.io.api.Binary +import org.apache.parquet.io.api.RecordConsumer +import org.apache.sedona.common.utils.GeomUtils +import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData.{GEOPARQUET_COVERING_KEY, GEOPARQUET_CRS_KEY, GEOPARQUET_VERSION_KEY, VERSION, createCoveringColumnMetadata} +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetWriteSupport.GeometryColumnInfo +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ +import org.json4s.{DefaultFormats, Extraction, JValue} +import org.json4s.jackson.JsonMethods.parse +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.io.WKBWriter + +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * A Parquet [[WriteSupport]] implementation that writes Catalyst [[InternalRow]]s as Parquet + * messages. This class can write Parquet data in two modes: + * + * - Standard mode: Parquet data are written in standard format defined in parquet-format spec. + * - Legacy mode: Parquet data are written in legacy format compatible with Spark 1.4 and prior. + * + * This behavior can be controlled by SQL option `spark.sql.parquet.writeLegacyFormat`. The value + * of this option is propagated to this class by the `init()` method and its Hadoop configuration + * argument. + */ +class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging { + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to the record consumer. + // Here we are using `SpecializedGetters` rather than `InternalRow` so that we can directly access + // data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // Schema of the `InternalRow`s to be written + private var schema: StructType = _ + + // `ValueWriter`s for all fields of the schema + private var rootFieldWriters: Array[ValueWriter] = _ + + // The Parquet `RecordConsumer` to which all `InternalRow`s are written + private var recordConsumer: RecordConsumer = _ + + // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions + private var writeLegacyParquetFormat: Boolean = _ + + // Which parquet timestamp type to use when writing. 
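+ // One of INT96, TIMESTAMP_MICROS or TIMESTAMP_MILLIS, propagated from
+ // spark.sql.parquet.outputTimestampType by init().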
+ private var outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = _ + + // Reusable byte array used to write timestamps as Parquet INT96 values + private val timestampBuffer = new Array[Byte](12) + + // Reusable byte array used to write decimal values + private val decimalBuffer = + new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) + + private val datetimeRebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_WRITE)) + + private val dateRebaseFunc = + GeoDataSourceUtils.creteDateRebaseFuncInWrite(datetimeRebaseMode, "Parquet") + + private val timestampRebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInWrite(datetimeRebaseMode, "Parquet") + + private val int96RebaseMode = LegacyBehaviorPolicy.withName( + SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_WRITE)) + + private val int96RebaseFunc = + GeoDataSourceUtils.creteTimestampRebaseFuncInWrite(int96RebaseMode, "Parquet INT96") + + // A mapping from geometry field ordinal to bounding box. According to the geoparquet specification, + // "Geometry columns MUST be at the root of the schema", so we don't need to worry about geometry + // fields in nested structures. + private val geometryColumnInfoMap: mutable.Map[Int, GeometryColumnInfo] = mutable.Map.empty + + private var geoParquetVersion: Option[String] = None + private var defaultGeoParquetCrs: Option[JValue] = None + private val geoParquetColumnCrsMap: mutable.Map[String, Option[JValue]] = mutable.Map.empty + private val geoParquetColumnCoveringMap: mutable.Map[String, Covering] = mutable.Map.empty + + override def init(configuration: Configuration): WriteContext = { + val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) + this.schema = StructType.fromString(schemaString) + this.writeLegacyParquetFormat = { + // `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set in ParquetRelation + assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != null) + configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean + } + + this.outputTimestampType = { + val key = SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key + assert(configuration.get(key) != null) + SQLConf.ParquetOutputTimestampType.withName(configuration.get(key)) + } + + this.rootFieldWriters = schema.zipWithIndex + .map { case (field, ordinal) => + makeWriter(field.dataType, Some(ordinal)) + } + .toArray[ValueWriter] + + if (geometryColumnInfoMap.isEmpty) { + throw new RuntimeException("No geometry column found in the schema") + } + + geoParquetVersion = configuration.get(GEOPARQUET_VERSION_KEY) match { + case null => Some(VERSION) + case version: String => Some(version) + } + defaultGeoParquetCrs = configuration.get(GEOPARQUET_CRS_KEY) match { + case null => + // If no CRS is specified, we write null to the crs metadata field. This is for compatibility with + // geopandas 0.10.0 and earlier versions, which requires crs field to be present. + Some(org.json4s.JNull) + case "" => None + case crs: String => Some(parse(crs)) + } + geometryColumnInfoMap.keys.map(schema(_).name).foreach { name => + Option(configuration.get(GEOPARQUET_CRS_KEY + "." 
+ name)).foreach { + case "" => geoParquetColumnCrsMap.put(name, None) + case crs: String => geoParquetColumnCrsMap.put(name, Some(parse(crs))) + } + } + Option(configuration.get(GEOPARQUET_COVERING_KEY)).foreach { coveringColumnName => + if (geometryColumnInfoMap.size > 1) { + throw new IllegalArgumentException( + s"$GEOPARQUET_COVERING_KEY is ambiguous when there are multiple geometry columns." + + s"Please specify $GEOPARQUET_COVERING_KEY. for configured geometry column.") + } + val geometryColumnName = schema(geometryColumnInfoMap.keys.head).name + val covering = createCoveringColumnMetadata(coveringColumnName, schema) + geoParquetColumnCoveringMap.put(geometryColumnName, covering) + } + geometryColumnInfoMap.keys.map(schema(_).name).foreach { name => + Option(configuration.get(GEOPARQUET_COVERING_KEY + "." + name)).foreach { + coveringColumnName => + val covering = createCoveringColumnMetadata(coveringColumnName, schema) + geoParquetColumnCoveringMap.put(name, covering) + } + } + + val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) + val sparkSqlParquetRowMetadata = GeoParquetWriteSupport.getSparkSqlParquetRowMetadata(schema) + val metadata = Map( + SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT, + ParquetReadSupport.SPARK_METADATA_KEY -> sparkSqlParquetRowMetadata) ++ { + if (datetimeRebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some("org.apache.spark.legacyDateTime" -> "") + } else { + None + } + } ++ { + if (int96RebaseMode == LegacyBehaviorPolicy.LEGACY) { + Some("org.apache.spark.legacyINT96" -> "") + } else { + None + } + } + + logInfo(s"""Initialized Parquet WriteSupport with Catalyst schema: + |${schema.prettyJson} + |and corresponding Parquet message type: + |$messageType + """.stripMargin) + + new WriteContext(messageType, metadata.asJava) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + + override def finalizeWrite(): WriteSupport.FinalizedWriteContext = { + val metadata = new util.HashMap[String, String]() + if (geometryColumnInfoMap.nonEmpty) { + val primaryColumnIndex = geometryColumnInfoMap.keys.head + val primaryColumn = schema.fields(primaryColumnIndex).name + val columns = geometryColumnInfoMap.map { case (ordinal, columnInfo) => + val columnName = schema.fields(ordinal).name + val geometryTypes = columnInfo.seenGeometryTypes.toSeq + val bbox = if (geometryTypes.nonEmpty) { + Seq( + columnInfo.bbox.minX, + columnInfo.bbox.minY, + columnInfo.bbox.maxX, + columnInfo.bbox.maxY) + } else Seq(0.0, 0.0, 0.0, 0.0) + val crs = geoParquetColumnCrsMap.getOrElse(columnName, defaultGeoParquetCrs) + val covering = geoParquetColumnCoveringMap.get(columnName) + columnName -> GeometryFieldMetaData("WKB", geometryTypes, bbox, crs, covering) + }.toMap + val geoParquetMetadata = GeoParquetMetaData(geoParquetVersion, primaryColumn, columns) + val geoParquetMetadataJson = GeoParquetMetaData.toJson(geoParquetMetadata) + metadata.put("geo", geoParquetMetadataJson) + } + new FinalizedWriteContext(metadata) + } + + override def write(row: InternalRow): Unit = { + consumeMessage { + writeFields(row, schema, rootFieldWriters) + } + } + + private def writeFields( + row: InternalRow, + schema: StructType, + fieldWriters: Array[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + if (!row.isNullAt(i)) { + consumeField(schema(i).name, i) { + fieldWriters(i).apply(row, i) + } + } + i += 1 + } + } + + private def makeWriter(dataType: DataType, rootOrdinal: Option[Int] = 
None): ValueWriter = { + dataType match { + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getShort(ordinal)) + + case DateType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(dateRebaseFunc(row.getInt(ordinal))) + + case IntegerType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addInteger(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => recordConsumer.addFloat(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addDouble(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary( + Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes)) + + case TimestampType => + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + (row: SpecializedGetters, ordinal: Int) => + val micros = int96RebaseFunc(row.getLong(ordinal)) + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(micros) + val buf = ByteBuffer.wrap(timestampBuffer) + buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) + recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + (row: SpecializedGetters, ordinal: Int) => + val micros = row.getLong(ordinal) + recordConsumer.addLong(timestampRebaseFunc(micros)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + (row: SpecializedGetters, ordinal: Int) => + val micros = row.getLong(ordinal) + val millis = GeoDateTimeUtils.microsToMillis(timestampRebaseFunc(micros)) + recordConsumer.addLong(millis) + } + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal))) + + case DecimalType.Fixed(precision, scale) => + makeDecimalWriter(precision, scale) + + case t: StructType => + val fieldWriters = t.map(_.dataType).map(makeWriter(_, None)).toArray[ValueWriter] + (row: SpecializedGetters, ordinal: Int) => + consumeGroup { + writeFields(row.getStruct(ordinal, t.length), t, fieldWriters) + } + + case t: ArrayType => makeArrayWriter(t) + + case t: MapType => makeMapWriter(t) + + case GeometryUDT => + val geometryColumnInfo = rootOrdinal match { + case Some(ordinal) => + geometryColumnInfoMap.getOrElseUpdate(ordinal, new GeometryColumnInfo()) + case None => null + } + (row: SpecializedGetters, ordinal: Int) => { + val serializedGeometry = row.getBinary(ordinal) + val geom = GeometryUDT.deserialize(serializedGeometry) + val wkbWriter = new WKBWriter(GeomUtils.getDimension(geom)) + recordConsumer.addBinary(Binary.fromReusedByteArray(wkbWriter.write(geom))) + if (geometryColumnInfo != null) { + geometryColumnInfo.update(geom) + } + } + + case t: UserDefinedType[_] => makeWriter(t.sqlType) + + // TODO Adds IntervalType support + case _ => sys.error(s"Unsupported data type $dataType.") + } + } + + private def makeDecimalWriter(precision: Int, scale: Int): ValueWriter = { + assert( + precision <= DecimalType.MAX_PRECISION, + s"Decimal precision $precision exceeds max precision 
${DecimalType.MAX_PRECISION}") + + val numBytes = Decimal.minBytesForPrecision(precision) + + val int32Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addInteger(unscaledLong.toInt) + } + + val int64Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addLong(unscaledLong) + } + + val binaryWriterUsingUnscaledLong = + (row: SpecializedGetters, ordinal: Int) => { + // When the precision is low enough (<= 18) to squeeze the decimal value into a `Long`, we + // can build a fixed-length byte array with length `numBytes` using the unscaled `Long` + // value and the `decimalBuffer` for better performance. + val unscaled = row.getDecimal(ordinal, precision, scale).toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + + while (i < numBytes) { + decimalBuffer(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + + recordConsumer.addBinary(Binary.fromReusedByteArray(decimalBuffer, 0, numBytes)) + } + + val binaryWriterUsingUnscaledBytes = + (row: SpecializedGetters, ordinal: Int) => { + val decimal = row.getDecimal(ordinal, precision, scale) + val bytes = decimal.toJavaBigDecimal.unscaledValue().toByteArray + val fixedLengthBytes = if (bytes.length == numBytes) { + // If the length of the underlying byte array of the unscaled `BigInteger` happens to be + // `numBytes`, just reuse it, so that we don't bother copying it to `decimalBuffer`. + bytes + } else { + // Otherwise, the length must be less than `numBytes`. In this case we copy contents of + // the underlying bytes with padding sign bytes to `decimalBuffer` to form the result + // fixed-length byte array. + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + + recordConsumer.addBinary(Binary.fromReusedByteArray(fixedLengthBytes, 0, numBytes)) + } + + writeLegacyParquetFormat match { + // Standard mode, 1 <= precision <= 9, writes as INT32 + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer + + // Standard mode, 10 <= precision <= 18, writes as INT64 + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer + + // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong + + // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY + case _ => binaryWriterUsingUnscaledBytes + } + } + + def makeArrayWriter(arrayType: ArrayType): ValueWriter = { + val elementWriter = makeWriter(arrayType.elementType) + + def threeLevelArrayWriter(repeatedGroupName: String, elementFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < array.numElements()) { + consumeGroup { + // Only creates the element field if the current array element is not null. 
+ if (!array.isNullAt(i)) { + consumeField(elementFieldName, 0) { + elementWriter.apply(array, i) + } + } + } + i += 1 + } + } + } + } + } + + def twoLevelArrayWriter(repeatedFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedFieldName, 0) { + var i = 0 + while (i < array.numElements()) { + elementWriter.apply(array, i) + i += 1 + } + } + } + } + } + + (writeLegacyParquetFormat, arrayType.containsNull) match { + case (legacyMode @ false, _) => + // Standard mode: + // + // group (LIST) { + // repeated group list { + // ^~~~ repeatedGroupName + // element; + // ^~~~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "list", elementFieldName = "element") + + case (legacyMode @ true, nullableElements @ true) => + // Legacy mode, with nullable elements: + // + // group (LIST) { + // optional group bag { + // ^~~ repeatedGroupName + // repeated array; + // ^~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "bag", elementFieldName = "array") + + case (legacyMode @ true, nullableElements @ false) => + // Legacy mode, with non-nullable elements: + // + // group (LIST) { + // repeated array; + // ^~~~~ repeatedFieldName + // } + twoLevelArrayWriter(repeatedFieldName = "array") + } + } + + private def makeMapWriter(mapType: MapType): ValueWriter = { + val keyWriter = makeWriter(mapType.keyType) + val valueWriter = makeWriter(mapType.valueType) + val repeatedGroupName = if (writeLegacyParquetFormat) { + // Legacy mode: + // + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // ^~~ repeatedGroupName + // required key; + // value; + // } + // } + "map" + } else { + // Standard mode: + // + // group (MAP) { + // repeated group key_value { + // ^~~~~~~~~ repeatedGroupName + // required key; + // value; + // } + // } + "key_value" + } + + (row: SpecializedGetters, ordinal: Int) => { + val map = row.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + + consumeGroup { + // Only creates the repeated field if the map is non-empty. + if (map.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < map.numElements()) { + consumeGroup { + consumeField("key", 0) { + keyWriter.apply(keyArray, i) + } + + // Only creates the "value" field if the value if non-empty + if (!map.valueArray().isNullAt(i)) { + consumeField("value", 1) { + valueWriter.apply(valueArray, i) + } + } + } + i += 1 + } + } + } + } + } + } + + private def consumeMessage(f: => Unit): Unit = { + recordConsumer.startMessage() + f + recordConsumer.endMessage() + } + + private def consumeGroup(f: => Unit): Unit = { + recordConsumer.startGroup() + f + recordConsumer.endGroup() + } + + private def consumeField(field: String, index: Int)(f: => Unit): Unit = { + recordConsumer.startField(field, index) + f + recordConsumer.endField(field, index) + } +} + +object GeoParquetWriteSupport { + class GeometryColumnInfo { + val bbox: GeometryColumnBoundingBox = new GeometryColumnBoundingBox() + + // GeoParquet column metadata has a `geometry_types` property, which contains a list of geometry types + // that are present in the column. + val seenGeometryTypes: mutable.Set[String] = mutable.Set.empty + + def update(geom: Geometry): Unit = { + bbox.update(geom) + // In case of 3D geometries, a " Z" suffix gets added (e.g. ["Point Z"]). 
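+ // A 2D geometry is recorded under its plain type name (e.g. "Point"), while a 3D one
+ // gets the " Z" suffix (e.g. "Point Z").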
+ val hasZ = { + val coordinate = geom.getCoordinate + if (coordinate != null) !coordinate.getZ.isNaN else false + } + val geometryType = if (!hasZ) geom.getGeometryType else geom.getGeometryType + " Z" + seenGeometryTypes.add(geometryType) + } + } + + class GeometryColumnBoundingBox( + var minX: Double = Double.PositiveInfinity, + var minY: Double = Double.PositiveInfinity, + var maxX: Double = Double.NegativeInfinity, + var maxY: Double = Double.NegativeInfinity) { + def update(geom: Geometry): Unit = { + val env = geom.getEnvelopeInternal + minX = math.min(minX, env.getMinX) + minY = math.min(minY, env.getMinY) + maxX = math.max(maxX, env.getMaxX) + maxY = math.max(maxY, env.getMaxY) + } + } + + private def getSparkSqlParquetRowMetadata(schema: StructType): String = { + val fields = schema.fields.map { field => + field.dataType match { + case _: GeometryUDT => + // Don't write the GeometryUDT type to the Parquet metadata. Write the type as binary for maximum + // compatibility. + field.copy(dataType = BinaryType) + case _ => field + } + } + StructType(fields).json + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala new file mode 100644 index 0000000000..aadca3a60f --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +// Needed by Sedona to support Spark 3.0 - 3.3 +object GeoSchemaMergeUtils { + + def mergeSchemasInParallel( + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus], + schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType]) + : Option[StructType] = { + val serializedConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(parameters)) + + // !! HACK ALERT !! + // Here is a hack for Parquet, but it can be used by Orc as well. + // + // Parquet requires `FileStatus`es to read footers. + // Here we try to send cached `FileStatus`es to executor side to avoid fetching them again. + // However, `FileStatus` is not `Serializable` + // but only `Writable`. 
What makes it worse, for some reason, `FileStatus` doesn't play well + // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These + // facts virtually prevents us to serialize `FileStatus`es. + // + // Since Parquet only relies on path and length information of those `FileStatus`es to read + // footers, here we just extract them (which can be easily serialized), send them to executor + // side, and resemble fake `FileStatus`es there. + val partialFileStatusInfo = files.map(f => (f.getPath.toString, f.getLen)) + + // Set the number of partitions to prevent following schema reads from generating many tasks + // in case of a small number of orc files. + val numParallelism = Math.min( + Math.max(partialFileStatusInfo.size, 1), + sparkSession.sparkContext.defaultParallelism) + + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + + // Issues a Spark job to read Parquet/ORC schema in parallel. + val partiallyMergedSchemas = + sparkSession.sparkContext + .parallelize(partialFileStatusInfo, numParallelism) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + val schemas = schemaReader(fakeFileStatuses, serializedConf.value, ignoreCorruptFiles) + + if (schemas.isEmpty) { + Iterator.empty + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergedSchema.merge(schema) + } catch { + case cause: SparkException => + throw new SparkException(s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } + } + .collect() + + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { + case cause: SparkException => + throw new SparkException(s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala new file mode 100644 index 0000000000..43e1ababb7 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataDataSource.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Data source for reading GeoParquet metadata. This could be accessed using the `spark.read` + * interface: + * {{{ + * val df = spark.read.format("geoparquet.metadata").load("path/to/geoparquet") + * }}} + */ +class GeoParquetMetadataDataSource extends FileDataSourceV2 with DataSourceRegister { + override val shortName: String = "geoparquet.metadata" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + None, + fallbackFileFormat) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + GeoParquetMetadataTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala new file mode 100644 index 0000000000..1fe2faa2e0 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataPartitionReaderFactory.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods.{compact, render} + +case class GeoParquetMetadataPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + filters: Seq[Filter]) + extends FilePartitionReaderFactory { + + override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { + val iter = GeoParquetMetadataPartitionReaderFactory.readFile( + broadcastedConf.value.value, + partitionedFile, + readDataSchema) + val fileReader = new PartitionReaderFromIterator[InternalRow](iter) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFile.partitionValues) + } +} + +object GeoParquetMetadataPartitionReaderFactory { + private def readFile( + configuration: Configuration, + partitionedFile: PartitionedFile, + readDataSchema: StructType): Iterator[InternalRow] = { + val filePath = partitionedFile.filePath + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath), configuration)) + .getFooter + .getFileMetaData + .getKeyValueMetaData + val row = GeoParquetMetaData.parseKeyValueMetaData(metadata) match { + case Some(geo) => + val geoColumnsMap = geo.columns.map { case (columnName, columnMetadata) => + implicit val formats: org.json4s.Formats = DefaultFormats + import org.json4s.jackson.Serialization + val columnMetadataFields: Array[Any] = Array( + UTF8String.fromString(columnMetadata.encoding), + new GenericArrayData(columnMetadata.geometryTypes.map(UTF8String.fromString).toArray), + new GenericArrayData(columnMetadata.bbox.toArray), + columnMetadata.crs + .map(projjson => UTF8String.fromString(compact(render(projjson)))) + .getOrElse(UTF8String.fromString("")), + columnMetadata.covering + .map(covering => UTF8String.fromString(Serialization.write(covering))) + .orNull) + val columnMetadataStruct = new GenericInternalRow(columnMetadataFields) + UTF8String.fromString(columnName) -> columnMetadataStruct + } + val fields: Array[Any] = Array( + UTF8String.fromString(filePath), + UTF8String.fromString(geo.version.orNull), + UTF8String.fromString(geo.primaryColumn), + ArrayBasedMapData(geoColumnsMap)) + new GenericInternalRow(fields) + case None => + // Not a GeoParquet file, return a row with null metadata values. 
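+ // (e.g. plain Parquet files without the "geo" footer key end up here)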
+ val fields: Array[Any] = Array(UTF8String.fromString(filePath), null, null, null) + new GenericInternalRow(fields) + } + Iterator(pruneBySchema(row, GeoParquetMetadataTable.schema, readDataSchema)) + } + + private def pruneBySchema( + row: InternalRow, + schema: StructType, + readDataSchema: StructType): InternalRow = { + // Projection push down for nested fields is not enabled, so this very simple implementation is enough. + val values: Array[Any] = readDataSchema.fields.map { field => + val index = schema.fieldIndex(field.name) + row.get(index, field.dataType) + } + new GenericInternalRow(values) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala new file mode 100644 index 0000000000..b86ab7a399 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ + +case class GeoParquetMetadataScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. + // We should use `readPartitionSchema` as the partition schema here. 
+ GeoParquetMetadataPartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + pushedFilters) + } + + override def getFileUnSplittableReason(path: Path): String = + "Reading parquet file metadata does not require splitting the file" + + // This is for compatibility with Spark 3.0. Spark 3.3 does not have this method + def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = { + copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala new file mode 100644 index 0000000000..6a25e4530c --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataScanBuilder.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.read.Scan
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class GeoParquetMetadataScanBuilder(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    schema: StructType,
+    dataSchema: StructType,
+    options: CaseInsensitiveStringMap)
+    extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
+  override def build(): Scan = {
+    GeoParquetMetadataScan(
+      sparkSession,
+      fileIndex,
+      dataSchema,
+      readDataSchema(),
+      readPartitionSchema(),
+      options,
+      getPushedDataFilters,
+      getPartitionFilters,
+      getDataFilters)
+  }
+
+  // The following methods use reflection to address compatibility issues for Spark 3.0 ~ 3.2
+
+  private def getPushedDataFilters: Array[Filter] = {
+    try {
+      val field = classOf[FileScanBuilder].getDeclaredField("pushedDataFilters")
+      field.setAccessible(true)
+      field.get(this).asInstanceOf[Array[Filter]]
+    } catch {
+      case _: NoSuchFieldException =>
+        Array.empty
+    }
+  }
+
+  private def getPartitionFilters: Seq[Expression] = {
+    try {
+      val field = classOf[FileScanBuilder].getDeclaredField("partitionFilters")
+      field.setAccessible(true)
+      field.get(this).asInstanceOf[Seq[Expression]]
+    } catch {
+      case _: NoSuchFieldException =>
+        Seq.empty
+    }
+  }
+
+  private def getDataFilters: Seq[Expression] = {
+    try {
+      val field = classOf[FileScanBuilder].getDeclaredField("dataFilters")
+      field.setAccessible(true)
+      field.get(this).asInstanceOf[Seq[Expression]]
+    } catch {
+      case _: NoSuchFieldException =>
+        Seq.empty
+    }
+  }
+}
diff --git a/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala
new file mode 100644
index 0000000000..845764fae5
--- /dev/null
+++ b/spark/spark-3.3/src/main/scala/org/apache/spark/sql/execution/datasources/v2/geoparquet/metadata/GeoParquetMetadataTable.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class GeoParquetMetadataTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + override def formatName: String = "GeoParquet Metadata" + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + Some(GeoParquetMetadataTable.schema) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new GeoParquetMetadataScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) +} + +object GeoParquetMetadataTable { + private val columnMetadataType = StructType( + Seq( + StructField("encoding", StringType, nullable = true), + StructField("geometry_types", ArrayType(StringType), nullable = true), + StructField("bbox", ArrayType(DoubleType), nullable = true), + StructField("crs", StringType, nullable = true), + StructField("covering", StringType, nullable = true))) + + private val columnsType = MapType(StringType, columnMetadataType, valueContainsNull = false) + + val schema: StructType = StructType( + Seq( + StructField("path", StringType, nullable = false), + StructField("version", StringType, nullable = true), + StructField("primary_column", StringType, nullable = true), + StructField("columns", columnsType, nullable = true))) +} diff --git a/spark/spark-3.3/src/test/resources/log4j2.properties b/spark/spark-3.3/src/test/resources/log4j2.properties new file mode 100644 index 0000000000..5f89859463 --- /dev/null +++ b/spark/spark-3.3/src/test/resources/log4j2.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file target/unit-tests.log +rootLogger.level = info +rootLogger.appenderRef.file.ref = File + +appender.file.type = File +appender.file.name = File +appender.file.fileName = target/unit-tests.log +appender.file.append = true +appender.file.layout.type = PatternLayout +appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Ignore messages below warning level from Jetty, because it's a bit verbose +logger.jetty.name = org.sparkproject.jetty +logger.jetty.level = warn diff --git a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala new file mode 100644 index 0000000000..421890c700 --- /dev/null +++ b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.spark.sql.Row +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.scalatest.BeforeAndAfterAll + +import java.util.Collections +import scala.collection.JavaConverters._ + +class GeoParquetMetadataTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation: String = resourceFolder + "geoparquet/" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + describe("GeoParquet Metadata tests") { + it("Reading GeoParquet Metadata") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.collect() + assert(metadataArray.length > 1) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("version") == "1.0.0-dev")) + assert(metadataArray.exists(_.getAs[String]("primary_column") == "geometry")) + assert(metadataArray.exists { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + columnsMap != null && columnsMap + .containsKey("geometry") && columnsMap.get("geometry").isInstanceOf[Row] + }) + assert(metadataArray.forall { row => + val columnsMap = row.getJavaMap(row.fieldIndex("columns")) + if (columnsMap == null || !columnsMap.containsKey("geometry")) true + else { + val columnMetadata = columnsMap.get("geometry").asInstanceOf[Row] + columnMetadata.getAs[String]("encoding") == "WKB" && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("bbox")) + .asScala + .forall(_.isInstanceOf[Double]) && + columnMetadata + .getList[Any](columnMetadata.fieldIndex("geometry_types")) + .asScala + .forall(_.isInstanceOf[String]) && + columnMetadata.getAs[String]("crs").nonEmpty && + 
columnMetadata.getAs[String]("crs") != "null" + } + }) + } + + it("Reading GeoParquet Metadata with column pruning") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df + .selectExpr("path", "substring(primary_column, 1, 2) AS partial_primary_column") + .collect() + assert(metadataArray.length > 1) + assert(metadataArray.forall(_.length == 2)) + assert(metadataArray.exists(_.getAs[String]("path").endsWith(".parquet"))) + assert(metadataArray.exists(_.getAs[String]("partial_primary_column") == "ge")) + } + + it("Reading GeoParquet Metadata of plain parquet files") { + val df = sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) + val metadataArray = df.where("path LIKE '%plain.parquet'").collect() + assert(metadataArray.nonEmpty) + assert(metadataArray.forall(_.getAs[String]("path").endsWith("plain.parquet"))) + assert(metadataArray.forall(_.getAs[String]("version") == null)) + assert(metadataArray.forall(_.getAs[String]("primary_column") == null)) + assert(metadataArray.forall(_.getAs[String]("columns") == null)) + } + + it("Read GeoParquet without CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "") + } + + it("Read GeoParquet with null CRS") { + val df = sparkSession.read + .format("geoparquet") + .load(geoparquetdatalocation + "/example-1.0.0-beta.1.parquet") + val geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + assert(metadata.getAs[String]("crs") == "null") + } + + it("Read GeoParquet with snake_case geometry column name and camelCase column name") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column_1", GeometryUDT, nullable = false), + StructField("geomColumn2", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_column_name_styles.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + val dfMeta = sparkSession.read.format("geoparquet.metadata").load(geoParquetSavePath) + val row = dfMeta.collect()(0) + val metadata = row.getJavaMap(row.fieldIndex("columns")) + assert(metadata.containsKey("geom_column_1")) + assert(!metadata.containsKey("geoColumn1")) + assert(metadata.containsKey("geomColumn2")) + assert(!metadata.containsKey("geom_column2")) + assert(!metadata.containsKey("geom_column_2")) + } + + it("Read GeoParquet with covering metadata") { + val dfMeta = sparkSession.read + .format("geoparquet.metadata") + .load(geoparquetdatalocation + "/example-1.1.0.parquet") + val row = dfMeta.collect()(0) + val metadata = 
row.getJavaMap(row.fieldIndex("columns")).get("geometry").asInstanceOf[Row] + val covering = metadata.getAs[String]("covering") + assert(covering.nonEmpty) + Seq("bbox", "xmin", "ymin", "xmax", "ymax").foreach { key => + assert(covering contains key) + } + } + } +} diff --git a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala new file mode 100644 index 0000000000..8f3cc3f1e5 --- /dev/null +++ b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.generateTestData +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.readGeoParquetMetaDataMap +import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.writeTestDataAsGeoParquet +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter +import org.locationtech.jts.geom.Coordinate +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.geom.GeometryFactory +import org.scalatest.prop.TableDrivenPropertyChecks + +import java.io.File +import java.nio.file.Files + +class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrivenPropertyChecks { + + val tempDir: String = + Files.createTempDirectory("sedona_geoparquet_test_").toFile.getAbsolutePath + val geoParquetDir: String = tempDir + "/geoparquet" + var df: DataFrame = _ + var geoParquetDf: DataFrame = _ + var geoParquetMetaDataMap: Map[Int, Seq[GeoParquetMetaData]] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + df = generateTestData(sparkSession) + writeTestDataAsGeoParquet(df, geoParquetDir) + geoParquetDf = sparkSession.read.format("geoparquet").load(geoParquetDir) + geoParquetMetaDataMap = readGeoParquetMetaDataMap(geoParquetDir) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir)) + + describe("GeoParquet spatial filter push down tests") { + it("Push down ST_Contains") { + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", + Seq.empty) + testFilter("ST_Contains(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + it("Push down ST_Covers") { + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", + Seq.empty) + testFilter("ST_Covers(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Covers(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Covers(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + it("Push down ST_Within") { + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", + Seq(1)) + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", + Seq(0)) + testFilter( + "ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_Within(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3)) + testFilter( + "ST_Within(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", + Seq(3)) + testFilter( + "ST_Within(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", + Seq.empty) + } + + it("Push down ST_CoveredBy") { + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", + Seq(1)) + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", + Seq(0)) + testFilter( + "ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_CoveredBy(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3)) + testFilter( + "ST_CoveredBy(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", + Seq(3)) + testFilter( + "ST_CoveredBy(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", + Seq.empty) + } + + it("Push down ST_Intersects") { + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", + Seq.empty) + testFilter("ST_Intersects(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq(3)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", + Seq(1, 3)) + } + + it("Push down ST_Equals") { + testFilter( + "ST_Equals(geom, ST_GeomFromText('POLYGON ((-16 -16, -16 -14, -14 -14, -14 -16, -16 -16))'))", + Seq(2)) + testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-15 -15)'))", Seq(2)) + 
testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-16 -16)'))", Seq(2)) + testFilter( + "ST_Equals(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", + Seq.empty) + } + + forAll(Table("<", "<=")) { op => + it(s"Push down ST_Distance $op d") { + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 1", Seq.empty) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 5", Seq.empty) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (3 4)')) $op 1", Seq(1)) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 7.1", Seq(0, 1, 2, 3)) + testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) $op 1", Seq(2)) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 2", + Seq.empty) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 3", + Seq(0, 1, 2, 3)) + testFilter( + s"ST_Distance(geom, ST_GeomFromText('LINESTRING (17 17, 18 18)')) $op 1", + Seq(1)) + } + } + + it("Push down And(filters...)") { + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1)) + testFilter( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", + Seq(3)) + } + + it("Push down Or(filters...)") { + testFilter( + "ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom) OR ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", + Seq(0, 1)) + testFilter( + "ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) <= 1 OR ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(1, 2)) + } + + it("Ignore negated spatial filters") { + testFilter( + "NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(0, 1, 2, 3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) AND NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(3)) + testFilter( + "ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) OR NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", + Seq(0, 1, 2, 3)) + } + + it("Mixed spatial filter with other filter") { + testFilter( + "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", + Seq(1, 3)) + } + } + + /** + * Test filter push down using specified query condition, and verify if the pushed down filter + * prunes regions as expected. We'll also verify the correctness of query results. 
+ * @param condition + * SQL query condition + * @param expectedPreservedRegions + * Regions that should be preserved after filter push down + */ + private def testFilter(condition: String, expectedPreservedRegions: Seq[Int]): Unit = { + val dfFiltered = geoParquetDf.where(condition) + val preservedRegions = getPushedDownSpatialFilter(dfFiltered) match { + case Some(spatialFilter) => resolvePreservedRegions(spatialFilter) + case None => (0 until 4) + } + assert(expectedPreservedRegions == preservedRegions) + val expectedResult = + df.where(condition).orderBy("region", "id").select("region", "id").collect() + val actualResult = dfFiltered.orderBy("region", "id").select("region", "id").collect() + assert(expectedResult sameElements actualResult) + } + + private def getPushedDownSpatialFilter(df: DataFrame): Option[GeoParquetSpatialFilter] = { + val executedPlan = df.queryExecution.executedPlan + val fileSourceScanExec = executedPlan.find(_.isInstanceOf[FileSourceScanExec]) + assert(fileSourceScanExec.isDefined) + val fileFormat = fileSourceScanExec.get.asInstanceOf[FileSourceScanExec].relation.fileFormat + assert(fileFormat.isInstanceOf[GeoParquetFileFormat]) + fileFormat.asInstanceOf[GeoParquetFileFormat].spatialFilter + } + + private def resolvePreservedRegions(spatialFilter: GeoParquetSpatialFilter): Seq[Int] = { + geoParquetMetaDataMap + .filter { case (_, metaDataList) => + metaDataList.exists(metadata => spatialFilter.evaluate(metadata.columns)) + } + .keys + .toSeq + } +} + +object GeoParquetSpatialFilterPushDownSuite { + case class TestDataItem(id: Int, region: Int, geom: Geometry) + + /** + * Generate test data centered at (0, 0). The entire dataset was divided into 4 quadrants, each + * with a unique region ID. The dataset contains 4 points and 4 polygons in each quadrant. + * @param sparkSession + * SparkSession object + * @return + * DataFrame containing test data + */ + def generateTestData(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + val regionCenters = Seq((-10, 10), (10, 10), (-10, -10), (10, -10)) + val testData = regionCenters.zipWithIndex.flatMap { case ((x, y), i) => + generateTestDataForRegion(i, x, y) + } + testData.toDF() + } + + private def generateTestDataForRegion(region: Int, centerX: Double, centerY: Double) = { + val factory = new GeometryFactory() + val points = Seq( + factory.createPoint(new Coordinate(centerX - 5, centerY + 5)), + factory.createPoint(new Coordinate(centerX + 5, centerY + 5)), + factory.createPoint(new Coordinate(centerX - 5, centerY - 5)), + factory.createPoint(new Coordinate(centerX + 5, centerY - 5))) + val polygons = points.map { p => + val envelope = p.getEnvelopeInternal + envelope.expandBy(1) + factory.toGeometry(envelope) + } + (points ++ polygons).zipWithIndex.map { case (g, i) => TestDataItem(i, region, g) } + } + + /** + * Write the test dataframe as GeoParquet files. Each region is written to a separate file. + * We'll test spatial filter push down by examining which regions were preserved/pruned by + * evaluating the pushed down spatial filters + * @param testData + * dataframe containing test data + * @param path + * path to write GeoParquet files + */ + def writeTestDataAsGeoParquet(testData: DataFrame, path: String): Unit = { + testData.coalesce(1).write.partitionBy("region").format("geoparquet").save(path) + } + + /** + * Load GeoParquet metadata for each region. Note that there could be multiple files for each + * region, thus each region ID was associated with a list of GeoParquet metadata. 
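+   * The test data was written with `partitionBy("region")`, so each region's files live under
+   * a `region=<id>` subdirectory of `path`.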
+ * @param path + * path to directory containing GeoParquet files + * @return + * Map of region ID to list of GeoParquet metadata + */ + def readGeoParquetMetaDataMap(path: String): Map[Int, Seq[GeoParquetMetaData]] = { + (0 until 4).map { k => + val geoParquetMetaDataSeq = readGeoParquetMetaDataByRegion(path, k) + k -> geoParquetMetaDataSeq + }.toMap + } + + private def readGeoParquetMetaDataByRegion( + geoParquetSavePath: String, + region: Int): Seq[GeoParquetMetaData] = { + val parquetFiles = new File(geoParquetSavePath + s"/region=$region") + .listFiles() + .filter(_.getName.endsWith(".parquet")) + parquetFiles.flatMap { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + GeoParquetMetaData.parseKeyValueMetaData(metadata) + } + } +} diff --git a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala new file mode 100644 index 0000000000..2da12eceb0 --- /dev/null +++ b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.log4j.{Level, Logger} +import org.apache.sedona.spark.SedonaContext +import org.apache.spark.sql.DataFrame +import org.scalatest.{BeforeAndAfterAll, FunSpec} + +trait TestBaseScala extends FunSpec with BeforeAndAfterAll { + Logger.getRootLogger().setLevel(Level.WARN) + Logger.getLogger("org.apache").setLevel(Level.WARN) + Logger.getLogger("com").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) + + val warehouseLocation = System.getProperty("user.dir") + "/target/" + val sparkSession = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + .config("spark.sql.warehouse.dir", warehouseLocation) + // We need to be explicit about broadcasting in tests. 
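+    // -1 disables automatic broadcast joins, so tests must opt into broadcasting explicitly.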
+ .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .getOrCreate() + + val resourceFolder = System.getProperty("user.dir") + "/../common/src/test/resources/" + + override def beforeAll(): Unit = { + SedonaContext.create(sparkSession) + } + + override def afterAll(): Unit = { + // SedonaSQLRegistrator.dropAll(spark) + // spark.stop + } + + def loadCsv(path: String): DataFrame = { + sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path) + } +} diff --git a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala new file mode 100644 index 0000000000..ccfd560c84 --- /dev/null +++ b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala @@ -0,0 +1,748 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.sql.Row +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.execution.datasources.parquet.{Covering, GeoParquetMetaData, ParquetReadSupport} +import org.apache.spark.sql.functions.{col, expr} +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.sedona_sql.expressions.st_constructors.{ST_Point, ST_PolygonFromEnvelope} +import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.json4s.jackson.parseJson +import org.locationtech.jts.geom.Geometry +import org.locationtech.jts.io.WKTReader +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.util.Collections +import java.util.concurrent.atomic.AtomicLong +import scala.collection.JavaConverters._ + +class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll { + val geoparquetdatalocation1: String = resourceFolder + "geoparquet/example1.parquet" + val geoparquetdatalocation2: String = resourceFolder + "geoparquet/example2.parquet" + val geoparquetdatalocation3: String = resourceFolder + "geoparquet/example3.parquet" + val geoparquetdatalocation4: String = resourceFolder + "geoparquet/example-1.0.0-beta.1.parquet" + val geoparquetdatalocation5: String = resourceFolder + "geoparquet/example-1.1.0.parquet" + val legacyparquetdatalocation: String = + resourceFolder + 
"parquet/legacy-parquet-nested-columns.snappy.parquet" + val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(geoparquetoutputlocation)) + + describe("GeoParquet IO tests") { + it("GEOPARQUET Test example1 i.e. naturalearth_lowers dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1) + val rows = df.collect()(0) + assert(rows.getAs[Long]("pop_est") == 920938) + assert(rows.getAs[String]("continent") == "Oceania") + assert(rows.getAs[String]("name") == "Fiji") + assert(rows.getAs[String]("iso_a3") == "FJI") + assert(rows.getAs[Double]("gdp_md_est") == 8374.0) + assert( + rows + .getAs[Geometry]("geometry") + .toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177.28504 -17.72465, 177.67087 -17.381140000000002, 178.12557 -17.50481)), ((-179.79332010904864 -16.020882256741224, -179.9173693847653 -16.501783135649397, -180 -16.555216566639196, -180 -16.067132663642447, -179.79332010904864 -16.020882256741224)))") + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample1.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample1.parquet") + val newrows = df2.collect()(0) + assert( + newrows + .getAs[Geometry]("geometry") + .toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177.28504 -17.72465, 177.67087 -17.381140000000002, 178.12557 -17.50481)), ((-179.79332010904864 -16.020882256741224, -179.9173693847653 -16.501783135649397, -180 -16.555216566639196, -180 -16.067132663642447, -179.79332010904864 -16.020882256741224)))") + } + it("GEOPARQUET Test example2 i.e. naturalearth_citie dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation2) + val rows = df.collect()(0) + assert(rows.getAs[String]("name") == "Vatican City") + assert( + rows + .getAs[Geometry]("geometry") + .toString == "POINT (12.453386544971766 41.903282179960115)") + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample2.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample2.parquet") + val newrows = df2.collect()(0) + assert(newrows.getAs[String]("name") == "Vatican City") + assert( + newrows + .getAs[Geometry]("geometry") + .toString == "POINT (12.453386544971766 41.903282179960115)") + } + it("GEOPARQUET Test example3 i.e. 
nybb dataset's Read and Write") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation3) + val rows = df.collect()(0) + assert(rows.getAs[Long]("BoroCode") == 5) + assert(rows.getAs[String]("BoroName") == "Staten Island") + assert(rows.getAs[Double]("Shape_Leng") == 330470.010332) + assert(rows.getAs[Double]("Shape_Area") == 1.62381982381e9) + assert(rows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022")) + df.write + .format("geoparquet") + .mode(SaveMode.Overwrite) + .save(geoparquetoutputlocation + "/gp_sample3.parquet") + val df2 = sparkSession.read + .format("geoparquet") + .load(geoparquetoutputlocation + "/gp_sample3.parquet") + val newrows = df2.collect()(0) + assert( + newrows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022")) + } + it("GEOPARQUET Test example-1.0.0-beta.1.parquet") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation4) + val count = df.count() + val rows = df.collect() + assert(rows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(count == rows.length) + + val geoParquetSavePath = geoparquetoutputlocation + "/gp_sample4.parquet" + df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val newRows = df2.collect() + assert(rows.length == newRows.length) + assert(newRows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(rows sameElements newRows) + + val parquetFiles = + new File(geoParquetSavePath).listFiles().filter(_.getName.endsWith(".parquet")) + parquetFiles.foreach { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + val geo = parseJson(metadata.get("geo")) + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + assert(columnName == "geometry") + val geomTypes = (geo \ "columns" \ "geometry" \ "geometry_types").extract[Seq[String]] + assert(geomTypes.nonEmpty) + val sparkSqlRowMetadata = metadata.get(ParquetReadSupport.SPARK_METADATA_KEY) + assert(!sparkSqlRowMetadata.contains("GeometryUDT")) + } + } + it("GEOPARQUET Test example-1.1.0.parquet") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation5) + val count = df.count() + val rows = df.collect() + assert(rows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(count == rows.length) + + val geoParquetSavePath = geoparquetoutputlocation + "/gp_sample5.parquet" + df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val newRows = df2.collect() + assert(rows.length == newRows.length) + assert(newRows(0).getAs[AnyRef]("geometry").isInstanceOf[Geometry]) + assert(rows sameElements newRows) + } + + it("GeoParquet with multiple geometry columns") { + val wktReader = new WKTReader() + val testData = Seq( + Row( + 1, + wktReader.read("POINT (1 2)"), + wktReader.read("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))")), + Row( + 2, + wktReader.read("POINT Z(1 2 3)"), + wktReader.read("POLYGON Z((0 0 2, 1 0 2, 1 1 2, 0 1 2, 0 0 2))")), + Row( + 3, + wktReader.read("MULTIPOINT (0 0, 1 1, 2 2)"), + wktReader.read("MULTILINESTRING ((0 0, 1 1), (2 2, 3 3))"))) + val schema = StructType( + Seq( + StructField("id", 
IntegerType, nullable = false), + StructField("g0", GeometryUDT, nullable = false), + StructField("g1", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) + val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + + // Find parquet files in geoParquetSavePath directory and validate their metadata + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val version = (geo \ "version").extract[String] + assert(version == GeoParquetMetaData.VERSION) + val g0Types = (geo \ "columns" \ "g0" \ "geometry_types").extract[Seq[String]] + val g1Types = (geo \ "columns" \ "g1" \ "geometry_types").extract[Seq[String]] + assert(g0Types.sorted == Seq("Point", "Point Z", "MultiPoint").sorted) + assert(g1Types.sorted == Seq("Polygon", "Polygon Z", "MultiLineString").sorted) + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNull) + assert(g1Crs == org.json4s.JNull) + } + + // Read GeoParquet with multiple geometry columns + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.schema.fields(1).dataType.isInstanceOf[GeometryUDT]) + assert(df2.schema.fields(2).dataType.isInstanceOf[GeometryUDT]) + val rows = df2.collect() + assert(testData.length == rows.length) + assert(rows(0).getAs[AnyRef]("g0").isInstanceOf[Geometry]) + assert(rows(0).getAs[AnyRef]("g1").isInstanceOf[Geometry]) + } + + it("GeoParquet save should work with empty dataframes") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("g", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/empty.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.schema.fields(1).dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val g0Types = (geo \ "columns" \ "g" \ "geometry_types").extract[Seq[String]] + val g0BBox = (geo \ "columns" \ "g" \ "bbox").extract[Seq[Double]] + assert(g0Types.isEmpty) + assert(g0BBox == Seq(0.0, 0.0, 0.0, 0.0)) + } + } + + it("GeoParquet save should work with snake_case column names") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geom_column", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/snake_case_column_name.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val geomField = df2.schema.fields(1) + assert(geomField.name == "geom_column") + assert(geomField.dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + } + + it("GeoParquet save should work with camelCase column names") { + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("geomColumn", GeometryUDT, nullable = false))) + val df = 
sparkSession.createDataFrame(Collections.emptyList[Row](), schema) + val geoParquetSavePath = geoparquetoutputlocation + "/camel_case_column_name.parquet" + df.write.format("geoparquet").mode("overwrite").save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + val geomField = df2.schema.fields(1) + assert(geomField.name == "geomColumn") + assert(geomField.dataType.isInstanceOf[GeometryUDT]) + assert(0 == df2.count()) + } + + it("GeoParquet save should write user specified version and crs to geo metadata") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation4) + // This CRS is taken from https://proj.org/en/9.3/specifications/projjson.html#geographiccrs + // with slight modification. + val projjson = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "NAD83(2011)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "NAD83 (National Spatial Reference System 2011)", + | "ellipsoid": { + | "name": "GRS 1980", + | "semi_major_axis": 6378137, + | "inverse_flattening": 298.257222101 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Horizontal component of 3D system.", + | "area": "Puerto Rico - onshore and offshore. United States (USA) onshore and offshore.", + | "bbox": { + | "south_latitude": 14.92, + | "west_longitude": 167.65, + | "north_latitude": 74.71, + | "east_longitude": -63.88 + | }, + | "id": { + | "authority": "EPSG", + | "code": 6318 + | } + |} + |""".stripMargin + var geoParquetSavePath = geoparquetoutputlocation + "/gp_custom_meta.parquet" + df.write + .format("geoparquet") + .option("geoparquet.version", "10.9.8") + .option("geoparquet.crs", projjson) + .mode("overwrite") + .save(geoParquetSavePath) + val df2 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df2.count() == df.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val version = (geo \ "version").extract[String] + val columnName = (geo \ "primary_column").extract[String] + assert(version == "10.9.8") + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs.isInstanceOf[org.json4s.JObject]) + assert(crs == parseJson(projjson)) + } + + // Setting crs to null explicitly + geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_null.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "null") + .mode("overwrite") + .save(geoParquetSavePath) + val df3 = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + assert(df3.count() == df.count()) + + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs == org.json4s.JNull) + } + + // Setting crs to "" to omit crs + geoParquetSavePath = geoparquetoutputlocation + "/gp_crs_omit.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: 
org.json4s.Formats = org.json4s.DefaultFormats + val columnName = (geo \ "primary_column").extract[String] + val crs = geo \ "columns" \ columnName \ "crs" + assert(crs == org.json4s.JNothing) + } + } + + it("GeoParquet save should support specifying per-column CRS") { + val wktReader = new WKTReader() + val testData = Seq( + Row( + 1, + wktReader.read("POINT (1 2)"), + wktReader.read("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"))) + val schema = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("g0", GeometryUDT, nullable = false), + StructField("g1", GeometryUDT, nullable = false))) + val df = sparkSession.createDataFrame(testData.asJava, schema).repartition(1) + + val projjson0 = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "NAD83(2011)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "NAD83 (National Spatial Reference System 2011)", + | "ellipsoid": { + | "name": "GRS 1980", + | "semi_major_axis": 6378137, + | "inverse_flattening": 298.257222101 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Horizontal component of 3D system.", + | "area": "Puerto Rico - onshore and offshore. United States (USA) onshore and offshore.", + | "bbox": { + | "south_latitude": 14.92, + | "west_longitude": 167.65, + | "north_latitude": 74.71, + | "east_longitude": -63.88 + | }, + | "id": { + | "authority": "EPSG", + | "code": 6318 + | } + |} + |""".stripMargin + + val projjson1 = + """ + |{ + | "$schema": "https://proj.org/schemas/v0.4/projjson.schema.json", + | "type": "GeographicCRS", + | "name": "Monte Mario (Rome)", + | "datum": { + | "type": "GeodeticReferenceFrame", + | "name": "Monte Mario (Rome)", + | "ellipsoid": { + | "name": "International 1924", + | "semi_major_axis": 6378388, + | "inverse_flattening": 297 + | }, + | "prime_meridian": { + | "name": "Rome", + | "longitude": 12.4523333333333 + | } + | }, + | "coordinate_system": { + | "subtype": "ellipsoidal", + | "axis": [ + | { + | "name": "Geodetic latitude", + | "abbreviation": "Lat", + | "direction": "north", + | "unit": "degree" + | }, + | { + | "name": "Geodetic longitude", + | "abbreviation": "Lon", + | "direction": "east", + | "unit": "degree" + | } + | ] + | }, + | "scope": "Geodesy, onshore minerals management.", + | "area": "Italy - onshore and offshore; San Marino, Vatican City State.", + | "bbox": { + | "south_latitude": 34.76, + | "west_longitude": 5.93, + | "north_latitude": 47.1, + | "east_longitude": 18.99 + | }, + | "id": { + | "authority": "EPSG", + | "code": 4806 + | } + |} + |""".stripMargin + + val geoParquetSavePath = geoparquetoutputlocation + "/multi_geoms_with_custom_crs.parquet" + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == parseJson(projjson1)) + } + + // Write without fallback CRS for g0 + df.write + .format("geoparquet") + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") 
+ .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNull) + assert(g1Crs == parseJson(projjson1)) + } + + // Fallback CRS is omitting CRS + df.write + .format("geoparquet") + .option("geoparquet.crs", "") + .option("geoparquet.crs.g1", projjson1) + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == org.json4s.JNothing) + assert(g1Crs == parseJson(projjson1)) + } + + // Write with CRS, explicitly set CRS to null for g1 + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", "null") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == org.json4s.JNull) + } + + // Write with CRS, explicitly omit CRS for g1 + df.write + .format("geoparquet") + .option("geoparquet.crs", projjson0) + .option("geoparquet.crs.g1", "") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + val g0Crs = geo \ "columns" \ "g0" \ "crs" + val g1Crs = geo \ "columns" \ "g1" \ "crs" + assert(g0Crs == parseJson(projjson0)) + assert(g1Crs == org.json4s.JNothing) + } + } + + it("GeoParquet load should raise exception when loading plain parquet files") { + val e = intercept[SparkException] { + sparkSession.read.format("geoparquet").load(resourceFolder + "geoparquet/plain.parquet") + } + assert(e.getMessage.contains("does not contain valid geo metadata")) + } + + it("GeoParquet load with spatial predicates") { + val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1) + val rows = + df.where(ST_Intersects(ST_Point(35.174722, -6.552465), col("geometry"))).collect() + assert(rows.length == 1) + assert(rows(0).getAs[String]("name") == "Tanzania") + } + + it("Filter push down for nested columns") { + import sparkSession.implicits._ + + // Prepare multiple GeoParquet files with bbox metadata. There should be 10 files in total, each file contains + // 1000 records. + val dfIds = (0 until 10000).toDF("id") + val dfGeom = dfIds + .withColumn( + "bbox", + expr("struct(id as minx, id as miny, id + 1 as maxx, id + 1 as maxy)")) + .withColumn("geom", expr("ST_PolygonFromEnvelope(id, id, id + 1, id + 1)")) + .withColumn("part_id", expr("CAST(id / 1000 AS INTEGER)")) + .coalesce(1) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_bbox.parquet" + dfGeom.write + .partitionBy("part_id") + .format("geoparquet") + .mode("overwrite") + .save(geoParquetSavePath) + + val sparkListener = new SparkListener() { + val recordsRead = new AtomicLong(0) + + def reset(): Unit = recordsRead.set(0) + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead + this.recordsRead.getAndAdd(recordsRead) + } + } + + sparkSession.sparkContext.addSparkListener(sparkListener) + try { + val df = sparkSession.read.format("geoparquet").load(geoParquetSavePath) + + // This should trigger filter push down to Parquet and only read one of the files. The number of records read + // should be less than 1000. 
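+        // ids 6000..6600 all map to part_id = 6 (CAST(id / 1000 AS INTEGER)), i.e. a single
+        // partition holding 1000 records, hence the upper bound asserted below.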
+ df.where("bbox.minx > 6000 and bbox.minx < 6600").count() + assert(sparkListener.recordsRead.get() <= 1000) + + // Reading these files using spatial filter. This should only read two of the files. + sparkListener.reset() + df.where(ST_Intersects(ST_PolygonFromEnvelope(7010, 7010, 8100, 8100), col("geom"))) + .count() + assert(sparkListener.recordsRead.get() <= 2000) + } finally { + sparkSession.sparkContext.removeSparkListener(sparkListener) + } + } + + it("Ready legacy parquet files written by Apache Sedona <= 1.3.1-incubating") { + val df = sparkSession.read + .format("geoparquet") + .option("legacyMode", "true") + .load(legacyparquetdatalocation) + val rows = df.collect() + assert(rows.nonEmpty) + rows.foreach { row => + assert(row.getAs[AnyRef]("geom").isInstanceOf[Geometry]) + assert(row.getAs[AnyRef]("struct_geom").isInstanceOf[Row]) + val structGeom = row.getAs[Row]("struct_geom") + assert(structGeom.getAs[AnyRef]("g0").isInstanceOf[Geometry]) + assert(structGeom.getAs[AnyRef]("g1").isInstanceOf[Geometry]) + } + } + + it("GeoParquet supports writing covering metadata") { + val df = sparkSession + .range(0, 100) + .toDF("id") + .withColumn("id", expr("CAST(id AS DOUBLE)")) + .withColumn("geometry", expr("ST_Point(id, id + 1)")) + .withColumn( + "test_cov", + expr("struct(id AS xmin, id + 1 AS ymin, id AS xmax, id + 1 AS ymax)")) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_covering_metadata.parquet" + df.write + .format("geoparquet") + .option("geoparquet.covering", "test_cov") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val coveringJsValue = geo \ "columns" \ "geometry" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov", "ymax")) + } + + df.write + .format("geoparquet") + .option("geoparquet.covering.geometry", "test_cov") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + val coveringJsValue = geo \ "columns" \ "geometry" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov", "ymax")) + } + } + + it("GeoParquet supports writing covering metadata for multiple columns") { + val df = sparkSession + .range(0, 100) + .toDF("id") + .withColumn("id", expr("CAST(id AS DOUBLE)")) + .withColumn("geom1", expr("ST_Point(id, id + 1)")) + .withColumn( + "test_cov1", + expr("struct(id AS xmin, id + 1 AS ymin, id AS xmax, id + 1 AS ymax)")) + .withColumn("geom2", expr("ST_Point(10 * id, 10 * id + 1)")) + .withColumn( + "test_cov2", + expr( + "struct(10 * id AS xmin, 10 * id + 1 AS ymin, 10 * id AS xmax, 10 * id + 1 AS ymax)")) + val geoParquetSavePath = geoparquetoutputlocation + "/gp_with_covering_metadata.parquet" + df.write + .format("geoparquet") + .option("geoparquet.covering.geom1", "test_cov1") + .option("geoparquet.covering.geom2", "test_cov2") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val 
formats: org.json4s.Formats = org.json4s.DefaultFormats + Seq(("geom1", "test_cov1"), ("geom2", "test_cov2")).foreach { + case (geomName, coveringName) => + val coveringJsValue = geo \ "columns" \ geomName \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq(coveringName, "xmin")) + assert(covering.bbox.ymin == Seq(coveringName, "ymin")) + assert(covering.bbox.xmax == Seq(coveringName, "xmax")) + assert(covering.bbox.ymax == Seq(coveringName, "ymax")) + } + } + + df.write + .format("geoparquet") + .option("geoparquet.covering.geom2", "test_cov2") + .mode("overwrite") + .save(geoParquetSavePath) + validateGeoParquetMetadata(geoParquetSavePath) { geo => + implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats + assert(geo \ "columns" \ "geom1" \ "covering" == org.json4s.JNothing) + val coveringJsValue = geo \ "columns" \ "geom2" \ "covering" + val covering = coveringJsValue.extract[Covering] + assert(covering.bbox.xmin == Seq("test_cov2", "xmin")) + assert(covering.bbox.ymin == Seq("test_cov2", "ymin")) + assert(covering.bbox.xmax == Seq("test_cov2", "xmax")) + assert(covering.bbox.ymax == Seq("test_cov2", "ymax")) + } + } + } + + def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = { + val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet")) + parquetFiles.foreach { filePath => + val metadata = ParquetFileReader + .open(HadoopInputFile.fromPath(new Path(filePath.getPath), new Configuration())) + .getFooter + .getFileMetaData + .getKeyValueMetaData + assert(metadata.containsKey("geo")) + val geo = parseJson(metadata.get("geo")) + body(geo) + } + } +} From 7afd08bd5a224d9d072ba4737de1749e80ebb962 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Wed, 7 Aug 2024 21:40:15 -0700 Subject: [PATCH 2/8] fix github workflow --- .github/workflows/java.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 36e91ed46a..51ec765905 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -104,7 +104,7 @@ jobs: if [ ${SPARK_VERSION:2:1} -gt "3" ]; then SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} fi - mvn -q clean install -Dspark=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS} + mvn -q clean install -Dspark.compat.version=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS} - run: mkdir staging - run: cp spark-shaded/target/sedona-*.jar staging - run: | From ac08ca36c6b8a4ecafa0edd3551ee1099f259cfe Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 8 Aug 2024 07:08:09 -0700 Subject: [PATCH 3/8] Update the spark profile versions --- pom.xml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pom.xml b/pom.xml index 9fbf5db5ec..f71e7b680e 100644 --- a/pom.xml +++ b/pom.xml @@ -674,17 +674,16 @@ - + sedona-spark-3.0 spark 3.0 - true - 3.0.0 + 3.0.3 3.0 2.17.2 @@ -697,10 +696,9 @@ spark 3.1 - true - 3.1.0 + 3.1.2 3.1 2.17.2 @@ -711,9 +709,8 @@ spark - 3.2 + 3.2.3 - true 3.2.0 @@ -722,7 +719,7 @@ - + sedona-spark-3.3 From 60ea5e37937a4206a22843c54407167f6d306784 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 8 Aug 2024 07:29:38 -0700 Subject: [PATCH 4/8] fix the spark compact version assignment --- .github/workflows/java.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 
51ec765905..ff70e32690 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -100,10 +100,7 @@ jobs: SCALA_VERSION: ${{ matrix.scala }} SKIP_TESTS: ${{ matrix.skipTests }} run: | - SPARK_COMPAT_VERSION="3.0" - if [ ${SPARK_VERSION:2:1} -gt "3" ]; then - SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} - fi + SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} mvn -q clean install -Dspark.compat.version=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS} - run: mkdir staging - run: cp spark-shaded/target/sedona-*.jar staging From b496c2b559b4b07eb62144150d232e7a5795edc5 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 8 Aug 2024 07:31:01 -0700 Subject: [PATCH 5/8] temporarily disable fail-fast --- .github/workflows/java.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index ff70e32690..5a7cf2b1d5 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -37,7 +37,7 @@ jobs: build: runs-on: ubuntu-22.04 strategy: - fail-fast: true + fail-fast: false matrix: include: - spark: 3.5.0 From b7f213e4fdce3b02c4f68ad8b91f72a02a06ff55 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 8 Aug 2024 17:13:55 -0700 Subject: [PATCH 6/8] revert mvn build property name --- .github/workflows/java.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 5a7cf2b1d5..7c3a552145 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -100,8 +100,9 @@ jobs: SCALA_VERSION: ${{ matrix.scala }} SKIP_TESTS: ${{ matrix.skipTests }} run: | + SPARK_COMPAT_VERSION="3.0" SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} - mvn -q clean install -Dspark.compat.version=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS} + mvn -q clean install -Dspark=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS} - run: mkdir staging - run: cp spark-shaded/target/sedona-*.jar staging - run: | From 0e95f5eafedf40fd61aa577e5328747420d6f24d Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 8 Aug 2024 19:11:52 -0700 Subject: [PATCH 7/8] fix the compact version for python and r pipelines --- .github/workflows/java.yml | 1 - .github/workflows/python.yml | 5 +---- .github/workflows/r.yml | 5 +---- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 7c3a552145..bb8a372076 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -100,7 +100,6 @@ jobs: SCALA_VERSION: ${{ matrix.scala }} SKIP_TESTS: ${{ matrix.skipTests }} run: | - SPARK_COMPAT_VERSION="3.0" SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} mvn -q clean install -Dspark=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS} - run: mkdir staging diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index b712f0c357..e7d1002d94 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -111,10 +111,7 @@ jobs: SPARK_VERSION: ${{ matrix.spark }} SCALA_VERSION: ${{ matrix.scala }} run: | - SPARK_COMPAT_VERSION="3.0" - if [ ${SPARK_VERSION:2:1} -gt "3" ]; then - SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} - fi + SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} mvn -q clean install -DskipTests -Dspark=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dgeotools - env: SPARK_VERSION: ${{ 
matrix.spark }} diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index 2ec23c7706..2f0841e0ee 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -102,10 +102,7 @@ jobs: key: apache.sedona-apache-spark-${{ steps.os-name.outputs.os-name }}-${{ env.SPARK_VERSION }} - name: Build Sedona libraries run: | - SPARK_COMPAT_VERSION="3.0" - if [ ${SPARK_VERSION:2:1} -gt "3" ]; then - SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} - fi + SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3} mvn -q clean install -DskipTests -Dspark=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} - name: Run tests run: | From 70fdb9261e268111d465298aec790951a6ae644b Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 8 Aug 2024 19:13:17 -0700 Subject: [PATCH 8/8] revert fail-fast to true --- .github/workflows/java.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index bb8a372076..49ab88c954 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -37,7 +37,7 @@ jobs: build: runs-on: ubuntu-22.04 strategy: - fail-fast: false + fail-fast: true matrix: include: - spark: 3.5.0