From c7c0b086a0b18424725433ade840d5121ac2b86e Mon Sep 17 00:00:00 2001 From: James Yu Date: Fri, 15 Jun 2018 21:04:04 -0700 Subject: [PATCH 1/9] add one supported type missing from the javadoc ## What changes were proposed in this pull request? The supported java.math.BigInteger type is not mentioned in the javadoc of Encoders.bean() ## How was this patch tested? only Javadoc fix Please review http://spark.apache.org/contributing.html before opening a pull request. Author: James Yu Closes #21544 from yuj/master. --- sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 0b95a8821b05a..b47ec0b72c638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -132,7 +132,7 @@ object Encoders { * - primitive types: boolean, int, double, etc. * - boxed types: Boolean, Integer, Double, etc. * - String - * - java.math.BigDecimal + * - java.math.BigDecimal, java.math.BigInteger * - time related: java.sql.Date, java.sql.Timestamp * - collection types: only array and java.util.List currently, map support is in progress * - nested java bean. From b0a935255951280b49c39968f6234163e2f0e379 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 18 Jun 2018 15:32:34 +0800 Subject: [PATCH 2/9] [SPARK-24573][INFRA] Runs SBT checkstyle after the build to work around a side-effect ## What changes were proposed in this pull request? Seems checkstyle affects the build in the PR builder in Jenkins. I can't reproduce in my local and seems it can only be reproduced in the PR builder. I was checking the places it goes through and this is just a speculation that checkstyle's compilation in SBT has a side effect to the assembly build. This PR proposes to run the SBT checkstyle after the build. ## How was this patch tested? Jenkins tests. Author: hyukjinkwon Closes #21579 from HyukjinKwon/investigate-javastyle. --- dev/run-tests.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5e8c8590b5c34..cd4590864b7d7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version): exec_sbt(profiles_and_goals) -def build_spark_assembly_sbt(hadoop_version): +def build_spark_assembly_sbt(hadoop_version, checkstyle=False): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["assembly/package"] @@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + if checkstyle: + run_java_style_checks() + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. 
# Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the # documentation build fails on a specific machine & environment in Jenkins but it was unable @@ -570,11 +573,13 @@ def main(): or f.endswith("scalastyle-config.xml") for f in changed_files): run_scala_style_checks() + should_run_java_style_checks = False if not changed_files or any(f.endswith(".java") or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - run_java_style_checks() + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") @@ -603,7 +608,7 @@ def main(): detect_binary_inop_with_mima(hadoop_version) # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. - build_spark_assembly_sbt(hadoop_version) + build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) From e219e692ef70c161f37a48bfdec2a94b29260004 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 19 Jun 2018 00:24:54 +0800 Subject: [PATCH 3/9] [SPARK-23772][SQL] Provide an option to ignore column of all null values or empty array during JSON schema inference ## What changes were proposed in this pull request? This pr added a new JSON option `dropFieldIfAllNull ` to ignore column of all null values or empty array/struct during JSON schema inference. ## How was this patch tested? Added tests in `JsonSuite`. Author: Takeshi Yamamuro Author: Xiangrui Meng Closes #20929 from maropu/SPARK-23772. --- python/pyspark/sql/readwriter.py | 5 +- .../spark/sql/catalyst/json/JSONOptions.scala | 3 ++ .../apache/spark/sql/DataFrameReader.scala | 2 + .../datasources/json/JsonInferSchema.scala | 40 +++++++-------- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 49 +++++++++++++++++++ 6 files changed, 80 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a0e20d39c20da..3efe2adb6e2a4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - encoding=None): + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. 
>>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2ff12acb2946f..c081772116f84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,6 +73,9 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index de6be5f76e15a..ec9352a7fa055 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -381,6 +381,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * that should be used for parsing. *
<li>`samplingRatio` (default is 1.0): defines fraction of input JSON objects used * for schema inferring.</li>
+ * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e7eed95a560a3..f6edc7bfb3750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -75,7 +75,7 @@ private[sql] object JsonInferSchema { // active SparkSession and `SQLConf.get` may point to the wrong configs. val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -181,33 +181,33 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. */ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType => Some(StringType) case other => Some(other) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..ef8dc3a325a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
<li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.</li>
+ * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..a8a4a524a97f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2427,4 +2427,53 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), Row(badJson)) } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + + // primitive types + Seq( + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) + + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil) + + // structs + Seq( + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", + """{"a":null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) + } + } } From bce177552564a4862bc979d39790cf553a477d74 Mon Sep 17 00:00:00 2001 From: trystanleftwich Date: Tue, 19 Jun 2018 00:34:24 +0800 Subject: [PATCH 4/9] [SPARK-24526][BUILD][TEST-MAVEN] Spaces in the build dir causes failures in the build/mvn script ## What changes were proposed in this pull request? Fix the call to ${MVN_BIN} to be wrapped in quotes so it will handle having spaces in the path. ## How was this patch tested? Ran the following to confirm using the build/mvn tool with a space in the build dir now works without error ``` mkdir /tmp/test\ spaces cd /tmp/test\ spaces git clone https://github.com/apache/spark.git cd spark # Remove all mvn references in PATH so the script will download mvn to the local dir ./build/mvn -DskipTests clean package ``` Please review http://spark.apache.org/contributing.html before opening a pull request. Author: trystanleftwich Closes #21534 from trystanleftwich/SPARK-24526. 
--- build/mvn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/mvn b/build/mvn index efa4f9364ea52..1405983982d4c 100755 --- a/build/mvn +++ b/build/mvn @@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" From 8f225e055c2031ca85d61721ab712170ab4e50c1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 18 Jun 2018 11:01:17 -0700 Subject: [PATCH 5/9] [SPARK-24548][SQL] Fix incorrect schema of Dataset with tuple encoders ## What changes were proposed in this pull request? When creating tuple expression encoders, we should give the serializer expressions of tuple items correct names, so we can have correct output schema when we use such tuple encoders. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21576 from viirya/SPARK-24548. --- .../catalyst/encoders/ExpressionEncoder.scala | 3 ++- .../org/apache/spark/sql/JavaDatasetSuite.java | 18 ++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..cbea3c017a265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -128,7 +128,7 @@ object ExpressionEncoder { case b: BoundReference if b == originalInputObject => newInputObject }) - if (enc.flat) { + val serializerExpr = if (enc.flat) { newSerializer.head } else { // For non-flat encoder, the input object is not top level anymore after being combined to @@ -146,6 +146,7 @@ object ExpressionEncoder { Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) If(nullCheck, Literal.create(null, struct.dataType), struct) } + Alias(serializerExpr, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index c132cab1b38cf..2c695fc58fd8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.junit.*; import org.junit.rules.ExpectedException; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; @@ -336,6 +337,23 @@ public void testTupleEncoder() { Assert.assertEquals(data5, ds5.collectAsList()); } + @Test + public void testTupleEncoderSchema() { + Encoder>> encoder = + Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())); + List>> data = Arrays.asList(tuple2("1", tuple2("a", "b")), + tuple2("2", tuple2("c", "d"))); + Dataset ds1 = spark.createDataset(data, encoder).toDF("value1", "value2"); + + JavaPairRDD> pairRDD = jsc.parallelizePairs(data); + Dataset ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder) + .toDF("value1", "value2"); + + Assert.assertEquals(ds1.schema(), ds2.schema()); + 
Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(), + ds2.select(expr("value2._1")).collectAsList()); + } + @Test public void testNestedTupleEncoder() { // test ((int, string), string) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d477d78dc14e3..093cee91d2f49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1466,6 +1466,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { + val encoder = Encoders.tuple(newStringEncoder, + Encoders.tuple(newStringEncoder, newStringEncoder)) + + val data = Seq(("a", ("1", "2")), ("b", ("3", "4"))) + val rdd = sparkContext.parallelize(data) + + val ds1 = spark.createDataset(rdd) + val ds2 = spark.createDataset(rdd)(encoder) + assert(ds1.schema == ds2.schema) + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 1737d45e08a5f1fb78515b14321721d7197b443a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Jun 2018 20:15:01 -0700 Subject: [PATCH 6/9] [SPARK-24478][SQL][FOLLOWUP] Move projection and filter push down to physical conversion ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/21503, to completely move operator pushdown to the planner rule. The code are mostly from https://github.com/apache/spark/pull/21319 ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21574 from cloud-fan/followup. --- .../v2/reader/SupportsReportStatistics.java | 7 +- .../datasources/v2/DataSourceV2Relation.scala | 109 ++++----------- .../datasources/v2/DataSourceV2Strategy.scala | 124 +++++++++++++----- 3 files changed, 123 insertions(+), 117 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index a79080a249ec8..926396414816c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -23,10 +23,9 @@ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. * - * Statistics are reported to the optimizer before a projection or any filters are pushed to the - * DataSourceReader. Implementations that return more accurate statistics based on projection and - * filters will not improve query performance until the planner can push operators before getting - * stats. + * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. + * Implementations that return more accurate statistics based on pushed operators will not improve + * query performance until the planner can push operators before getting stats. 
*/ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index e08af218513fd..7613eb210c659 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,17 +23,24 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} import org.apache.spark.sql.types.StructType +/** + * A logical plan representing a data source v2 scan. + * + * @param source An instance of a [[DataSourceV2]] implementation. + * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. + * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh + * [[DataSourceReader]]. 
+ */ case class DataSourceV2Relation( source: DataSourceV2, output: Seq[AttributeReference], options: Map[String, String], - userSpecifiedSchema: Option[StructType] = None) + userSpecifiedSchema: Option[StructType]) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ @@ -42,14 +49,7 @@ case class DataSourceV2Relation( override def simpleString: String = "RelationV2 " + metadataString - lazy val v2Options: DataSourceOptions = makeV2Options(options) - - def newReader: DataSourceReader = userSpecifiedSchema match { - case Some(userSchema) => - source.asReadSupportWithSchema.createReader(userSchema, v2Options) - case None => - source.asReadSupport.createReader(v2Options) - } + def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => @@ -139,83 +139,26 @@ object DataSourceV2Relation { source.getClass.getSimpleName } } - } - - private def makeV2Options(options: Map[String, String]): DataSourceOptions = { - new DataSourceOptions(options.asJava) - } - private def schema( - source: DataSourceV2, - v2Options: DataSourceOptions, - userSchema: Option[StructType]): StructType = { - val reader = userSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) + def createReader( + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): DataSourceReader = { + val v2Options = new DataSourceOptions(options.asJava) + userSpecifiedSchema match { + case Some(s) => + asReadSupportWithSchema.createReader(s, v2Options) + case _ => + asReadSupport.createReader(v2Options) + } } - reader.readSchema() } def create( source: DataSourceV2, options: Map[String, String], - userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, output, options, userSpecifiedSchema) - } - - def pushRequiredColumns( - relation: DataSourceV2Relation, - reader: DataSourceReader, - struct: StructType): Seq[AttributeReference] = { - reader match { - case projectionSupport: SupportsPushDownRequiredColumns => - projectionSupport.pruneColumns(struct) - // return the output columns from the relation that were projected - val attrMap = relation.output.map(a => a.name -> a).toMap - projectionSupport.readSchema().map(f => attrMap(f.name)) - case _ => - relation.output - } - } - - def pushFilters( - reader: DataSourceReader, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case r: SupportsPushDownCatalystFilters => - val postScanFilters = r.pushCatalystFilters(filters.toArray) - val pushedFilters = r.pushedCatalystFilters() - (postScanFilters, pushedFilters) - - case r: SupportsPushDownFilters => - // A map from translated data source filters to original catalyst filter expressions. - val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] - // Catalyst filter expression that can't be translated to data source filters. 
- val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] - - for (filterExpr <- filters) { - val translated = DataSourceStrategy.translateFilter(filterExpr) - if (translated.isDefined) { - translatedFilterToExpr(translated.get) = filterExpr - } else { - untranslatableExprs += filterExpr - } - } - - // Data source filters that need to be evaluated again after scanning. which means - // the data source cannot guarantee the rows returned can pass these filters. - // As a result we must return it so Spark can plan an extra filter operator. - val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) - // The filters which are marked as pushed to this data source - val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) - - (untranslatableExprs ++ postScanFilters, pushedFilters) - - case _ => (filters, Nil) - } + userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + val reader = source.createReader(options, userSpecifiedSchema) + DataSourceV2Relation( + source, reader.readSchema().toAttributes, options, userSpecifiedSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 8bf858c38d76c..182aa2906cf1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,51 +17,115 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.{execution, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import scala.collection.mutable + +import org.apache.spark.sql.{sources, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} object DataSourceV2Strategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(relation.output) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. 
- relation.output - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - val reader = relation.newReader + /** + * Pushes down filters to the data source reader + * + * @return pushed filter and post-scan filters. + */ + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (pushedFilters, postScanFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray) + .map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + (pushedFilters, untranslatableExprs ++ postScanFilters) + + case _ => (Nil, filters) + } + } - val output = DataSourceV2Relation.pushRequiredColumns(relation, reader, - projection.asInstanceOf[Seq[AttributeReference]].toStructType) + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return new output attributes after column pruning. + */ + // TODO: nested column pruning. + private def pruneColumns( + reader: DataSourceReader, + relation: DataSourceV2Relation, + exprs: Seq[Expression]): Seq[AttributeReference] = { + reader match { + case r: SupportsPushDownRequiredColumns => + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + if (neededOutput != relation.output) { + r.pruneColumns(neededOutput.toStructType) + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + r.readSchema().toAttributes.map { + // We have to keep the attribute id during transformation. + a => a.withExprId(nameToAttr(a.name).exprId) + } + } else { + relation.output + } + + case _ => relation.output + } + } - val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters) - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val reader = relation.newReader() + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. 
+ // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFilters) = pushFilters(reader, filters) + val output = pruneColumns(reader, relation, project ++ postScanFilters) + logInfo( + s""" + |Pushing operators to ${relation.source.getClass} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) val scan = DataSourceV2ScanExec( output, relation.source, relation.options, pushedFilters, reader) - val filter = postScanFilters.reduceLeftOption(And) - val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan) + val filterCondition = postScanFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) val withProjection = if (withFilter.output != project) { - execution.ProjectExec(project, withFilter) + ProjectExec(project, withFilter) } else { withFilter } From 9a75c18290fff7d116cf88a44f9120bf67d8bd27 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 18 Jun 2018 20:17:04 -0700 Subject: [PATCH 7/9] [SPARK-24542][SQL] UDF series UDFXPathXXXX allow users to pass carefully crafted XML to access arbitrary files ## What changes were proposed in this pull request? UDF series UDFXPathXXXX allow users to pass carefully crafted XML to access arbitrary files. Spark does not have built-in access control. When users use the external access control library, users might bypass them and access the file contents. This PR basically patches the Hive fix to Apache Spark. https://issues.apache.org/jira/browse/HIVE-18879 ## How was this patch tested? A unit test case Author: Xiao Li Closes #21549 from gatorsmile/xpathSecurity. --- .../expressions/xml/UDFXPathUtil.java | 28 ++++++++++++++++++- .../expressions/xml/UDFXPathUtilSuite.scala | 21 ++++++++++++++ .../xml/XPathExpressionSuite.scala | 5 ++-- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index d224332d8a6c9..023ec139652c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -21,6 +21,9 @@ import java.io.Reader; import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; @@ -37,9 +40,15 @@ * This is based on Hive's UDFXPathUtil implementation. 
*/ public class UDFXPathUtil { + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder = null; private XPath xpath = XPathFactory.newInstance().newXPath(); private ReusableStringReader reader = new ReusableStringReader(); private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; private String oldPath = null; @@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE return null; } + if (builder == null){ + try { + initializeDocumentBuilderFactory(); + builder = dbf.newDocumentBuilder(); + } catch (ParserConfigurationException e) { + throw new RuntimeException( + "Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + reader.set(xml); try { - return expression.evaluate(inputSource, qname); + return expression.evaluate(builder.parse(inputSource), qname); } catch (XPathExpressionException e) { throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } catch (Exception e) { + throw new RuntimeException("Error loading expression '" + oldPath + "'", e); } } + private void initializeDocumentBuilderFactory() throws ParserConfigurationException { + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index c4cde7091154b..0fec15bc42c17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(ret == "foo") } + test("embedFailure") { + import org.apache.commons.io.FileUtils + import java.io.File + val secretValue = String.valueOf(Math.random) + val tempFile = File.createTempFile("verifyembed", ".tmp") + tempFile.deleteOnExit() + val fname = tempFile.getAbsolutePath + + FileUtils.writeStringToFile(tempFile, secretValue) + + val xml = + s""" + | + |]> + |&embed; + """.stripMargin + val evaled = new UDFXPathUtil().evalString(xml, "/foo") + assert(evaled.isEmpty) + } + test("number eval") { var ret = util.evalNumber("truefalseb3c1-77", "a/c[2]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index bfa18a0919e45..c6f6d3abb860c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // Test error message for invalid XML document val 
e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) } - assert(e1.getCause.getMessage.contains("Invalid XML document") && - e1.getCause.getMessage.contains("/a>")) + assert(e1.getCause.getCause.getMessage.contains( + "XML document structures must start and end within the same entity.")) + assert(e1.getMessage.contains("/a>")) // Test error message for invalid xpath val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) } From a78a9046413255756653f70165520efd486fb493 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Jun 2018 10:42:08 -0700 Subject: [PATCH 8/9] [SPARK-24521][SQL][TEST] Fix ineffective test in CachedTableSuite ## What changes were proposed in this pull request? test("withColumn doesn't invalidate cached dataframe") in CachedTableSuite doesn't not work because: The UDF is executed and test count incremented when "df.cache()" is called and the subsequent "df.collect()" has no effect on the test result. This PR fixed this test and add another test for caching UDF. ## How was this patch tested? Add new tests. Author: Li Jin Closes #21531 from icexelloss/fix-cache-test. --- .../apache/spark/sql/CachedTableSuite.scala | 19 ---------- .../apache/spark/sql/DatasetCacheSuite.scala | 38 ++++++++++++++++++- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 81b7e18773f81..6982c22f4771d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -83,25 +83,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext }.sum } - test("withColumn doesn't invalidate cached dataframe") { - var evalCount = 0 - val myUDF = udf((x: String) => { evalCount += 1; "result" }) - val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) - df.cache() - - df.collect() - assert(evalCount === 1) - - df.collect() - assert(evalCount === 1) - - val df2 = df.withColumn("newColumn", lit(1)) - df2.collect() - - // We should not reevaluate the cached dataframe - assert(evalCount === 1) - } - test("cache temp table") { withTempView("tempTable") { testData.select('key).createOrReplaceTempView("tempTable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e0561ee2797a5..82a93f74dd76c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql +import org.scalatest.concurrent.TimeLimits +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel -class DatasetCacheSuite extends QueryTest with SharedSQLContext { +class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ test("get storage level") { @@ -96,4 +99,37 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { agged.unpersist() assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } + + test("persist and then withColumn") { + val df = Seq(("test", 1)).toDF("s", "i") + val df2 = df.withColumn("newColumn", lit(1)) + + df.cache() + assertCached(df) + assertCached(df2) + + 
df.count() + assertCached(df2) + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("cache UDF result correctly") { + val expensiveUDF = udf({x: Int => Thread.sleep(10000); x}) + val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + assertCached(df2) + + // udf has been evaluated during caching, and thus should not be re-evaluated here + failAfter(5 seconds) { + df2.collect() + } + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } } From 9dbe53eb6bb5916d28000f2c0d646cf23094ac11 Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 19 Jun 2018 10:52:51 -0700 Subject: [PATCH 9/9] [SPARK-24556][SQL] Always rewrite output partitioning in ReusedExchangeExec and InMemoryTableScanExec ## What changes were proposed in this pull request? Currently, ReusedExchange and InMemoryTableScanExec only rewrite output partitioning if child's partitioning is HashPartitioning and do nothing for other partitioning, e.g., RangePartitioning. We should always rewrite it, otherwise, unnecessary shuffle could be introduced like https://issues.apache.org/jira/browse/SPARK-24556. ## How was this patch tested? Add new tests. Author: yucai Closes #21564 from yucai/SPARK-24556. --- .../columnar/InMemoryTableScanExec.scala | 6 +- .../sql/execution/exchange/Exchange.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 64 ++++++++++++++++++- 3 files changed, 67 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 0b4dd76c7d860..997cf92449c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -169,8 +169,8 @@ case class InMemoryTableScanExec( // But the cached version could alias output, so we need to replace output. 
override def outputPartitioning: Partitioning = { relation.cachedPlan.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.cachedPlan.outputPartitioning + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 37d468739c613..d254345e8fa54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -703,6 +703,66 @@ class PlannerSuite extends SharedSQLContext { Range(1, 2, 1, 1))) df.queryExecution.executedPlan.execute() } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def 
checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements
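
A minimal, illustrative spark-shell sketch of what the SPARK-24556 change above achieves; it is not part of any patch in this series. After the fix, the in-memory scan of a cached, range-partitioned plan reports `RangePartitioning` over its own output attributes, mirroring the assertions in the new `PlannerSuite` test. The `spark` session and the internal `InMemoryTableScanExec`/`queryExecution` APIs are assumed to be those of a Spark shell at this point in the 2.x line.

```
// Illustrative only -- mirrors the new PlannerSuite assertions, not part of the patch series.
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec

// Cache a range-partitioned plan; the identical query below is answered from the cache.
spark.range(1, 100, 1, 10).toDF().persist()

val scans = spark.range(1, 100, 1, 10).toDF().queryExecution.executedPlan.collect {
  case s: InMemoryTableScanExec => s
}

// After SPARK-24556 the cached scan keeps RangePartitioning, and the partitioning
// expressions reference the scan's own output, so no extra exchange is planned on top of it.
assert(scans.head.outputPartitioning.isInstanceOf[RangePartitioning])
assert(scans.head.outputPartitioning.asInstanceOf[RangePartitioning]
  .references.subsetOf(scans.head.outputSet))
```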