From d68f3a726ffb4280d85268ef5a13b408b123ff48 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 8 Nov 2018 05:54:48 +0000 Subject: [PATCH 001/145] [SPARK-25676][FOLLOWUP][BUILD] Fix Scala 2.12 build error ## What changes were proposed in this pull request? This PR fixes the Scala-2.12 build. ## How was this patch tested? Manual build with Scala-2.12 profile. Closes #22970 from dongjoon-hyun/SPARK-25676-2.12. Authored-by: Dongjoon Hyun Signed-off-by: DB Tsai --- sql/core/benchmarks/WideTableBenchmark-results.txt | 14 +++++++------- .../execution/benchmark/WideTableBenchmark.scala | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/benchmarks/WideTableBenchmark-results.txt b/sql/core/benchmarks/WideTableBenchmark-results.txt index 3b41a3e036c4d..7bc388aaa549f 100644 --- a/sql/core/benchmarks/WideTableBenchmark-results.txt +++ b/sql/core/benchmarks/WideTableBenchmark-results.txt @@ -6,12 +6,12 @@ OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz projection on wide table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -split threshold 10 38932 / 39307 0.0 37128.1 1.0X -split threshold 100 31991 / 32556 0.0 30508.8 1.2X -split threshold 1024 10993 / 11041 0.1 10483.5 3.5X -split threshold 2048 8959 / 8998 0.1 8543.8 4.3X -split threshold 4096 8116 / 8134 0.1 7739.8 4.8X -split threshold 8196 8069 / 8098 0.1 7695.5 4.8X -split threshold 65536 57068 / 57339 0.0 54424.3 0.7X +split threshold 10 39634 / 39829 0.0 37798.3 1.0X +split threshold 100 30121 / 30571 0.0 28725.8 1.3X +split threshold 1024 9678 / 9725 0.1 9229.9 4.1X +split threshold 2048 8634 / 8662 0.1 8233.6 4.6X +split threshold 4096 8561 / 8576 0.1 8164.6 4.6X +split threshold 8192 8393 / 8408 0.1 8003.8 4.7X +split threshold 65536 57063 / 57273 0.0 54419.1 0.7X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala index ffefef1d4fce3..c61db3ce4b949 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Row import org.apache.spark.sql.internal.SQLConf /** @@ -42,7 +43,7 @@ object WideTableBenchmark extends SqlBasedBenchmark { Seq("10", "100", "1024", "2048", "4096", "8192", "65536").foreach { n => benchmark.addCase(s"split threshold $n", numIters = 5) { iter => withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> n) { - df.selectExpr(columns: _*).foreach(identity(_)) + df.selectExpr(columns: _*).foreach((x => x): Row => Unit) } } } From 17449a2e6b28ecce7a273284eab037e8aceb3611 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 14:48:23 +0800 Subject: [PATCH 002/145] [SPARK-25952][SQL] Passing actual schema to JacksonParser ## What changes were proposed in this pull request? The PR fixes an issue when the corrupt record column specified via `spark.sql.columnNameOfCorruptRecord` or JSON options `columnNameOfCorruptRecord` is propagated to JacksonParser, and returned row breaks an assumption in `FailureSafeParser` that the row must contain only actual data. 
The issue is fixed by passing actual schema without the corrupt record field into `JacksonParser`. ## How was this patch tested? Added a test with the corrupt record column in the middle of user's schema. Closes #22958 from MaxGekk/from_json-corrupt-record-schema. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../sql/catalyst/expressions/jsonExpressions.scala | 14 ++++++++------ .../org/apache/spark/sql/JsonFunctionsSuite.scala | 13 +++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index eafcb6161036e..52d0677f4022f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -569,14 +569,16 @@ case class JsonToStructs( throw new IllegalArgumentException(s"from_json() doesn't support the ${mode.name} mode. " + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") } - val rawParser = new JacksonParser(nullableSchema, parsedOptions, allowArrayAsStructs = false) - val createParser = CreateJacksonParser.utf8String _ - - val parserSchema = nullableSchema match { - case s: StructType => s - case other => StructType(StructField("value", other) :: Nil) + val (parserSchema, actualSchema) = nullableSchema match { + case s: StructType => + (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) + case other => + (StructType(StructField("value", other) :: Nil), other) } + val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false) + val createParser = CreateJacksonParser.utf8String _ + new FailureSafeParser[UTF8String]( input => rawParser.parse(input, createParser, identity[UTF8String]), mode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 2b09782faeeaa..d6b73387e84b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -578,4 +578,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { "Acceptable modes are PERMISSIVE and FAILFAST.")) } } + + test("corrupt record column in the middle") { + val schema = new StructType() + .add("a", IntegerType) + .add("_unparsed", StringType) + .add("b", IntegerType) + val badRec = """{"a" 1, "b": 11}""" + val df = Seq(badRec, """{"a": 2, "b": 12}""").toDS() + + checkAnswer( + df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))), + Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil) + } } From ee03f760b305e70a57c3b4409ec25897af348600 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 14:51:29 +0800 Subject: [PATCH 003/145] [SPARK-25955][TEST] Porting JSON tests for CSV functions ## What changes were proposed in this pull request? In the PR, I propose to port existing JSON tests from `JsonFunctionsSuite` that are applicable for CSV, and put them to `CsvFunctionsSuite`. In particular: - roundtrip `from_csv` to `to_csv`, and `to_csv` to `from_csv` - using `schema_of_csv` in `from_csv` - Java API `from_csv` - using `from_csv` and `to_csv` in exprs. Closes #22960 from MaxGekk/csv-additional-tests. 
Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../apache/spark/sql/CsvFunctionsSuite.scala | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 1dd8ec31ee111..b97ac380def63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -117,4 +117,51 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { "Acceptable modes are PERMISSIVE and FAILFAST.")) } } + + test("from_csv uses DDL strings for defining a schema - java") { + val df = Seq("""1,"haa"""").toDS() + checkAnswer( + df.select( + from_csv($"value", lit("a INT, b STRING"), new java.util.HashMap[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + + test("roundtrip to_csv -> from_csv") { + val df = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct") + val schema = df.schema(0).dataType.asInstanceOf[StructType] + val options = Map.empty[String, String] + val readback = df.select(to_csv($"struct").as("csv")) + .select(from_csv($"csv", schema, options).as("struct")) + + checkAnswer(df, readback) + } + + test("roundtrip from_csv -> to_csv") { + val df = Seq(Some("1"), None).toDF("csv") + val schema = new StructType().add("a", IntegerType) + val options = Map.empty[String, String] + val readback = df.select(from_csv($"csv", schema, options).as("struct")) + .select(to_csv($"struct").as("csv")) + + checkAnswer(df, readback) + } + + test("infers schemas of a CSV string and pass to to from_csv") { + val in = Seq("""0.123456789,987654321,"San Francisco"""").toDS() + val options = Map.empty[String, String].asJava + val out = in.select(from_csv('value, schema_of_csv("0.1,1,a"), options) as "parsed") + val expected = StructType(Seq(StructField( + "parsed", + StructType(Seq( + StructField("_c0", DoubleType, true), + StructField("_c1", IntegerType, true), + StructField("_c2", StringType, true)))))) + + assert(out.schema == expected) + } + + test("Support to_csv in SQL") { + val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") + checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil) + } } From 0a2e45fdb8baadf7a57eb06f319e96f95eedf298 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Nov 2018 16:32:25 +0800 Subject: [PATCH 004/145] Revert "[SPARK-23831][SQL] Add org.apache.derby to IsolatedClientLoader" This reverts commit a75571b46f813005a6d4b076ec39081ffab11844. 
--- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 1 - .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ------ 2 files changed, 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 1e7a0b187c8b3..c1d8fe53a9e8c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -186,7 +186,6 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 - name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 1de258f060943..0a522b6a11c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,10 +113,4 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } - - test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { - val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) - val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) - assert(!client1.equals(client2)) - } } From a3004d084c654237c60d02df1507333b92b860c6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 8 Nov 2018 03:40:28 -0800 Subject: [PATCH 005/145] [SPARK-25971][SQL] Ignore partition byte-size statistics in SQLQueryTestSuite ## What changes were proposed in this pull request? Currently, `SQLQueryTestSuite` is sensitive in terms of the bytes of parquet files in table partitions. If we change the default file format (from Parquet to ORC) or update the metadata of them, the test case should be changed accordingly. This PR aims to make `SQLQueryTestSuite` more robust by ignoring the partition byte statistics. ``` -Partition Statistics 1144 bytes, 2 rows +Partition Statistics [not included in comparison] bytes, 2 rows ``` ## How was this patch tested? Pass the Jenkins with the newly updated test cases. Closes #22972 from dongjoon-hyun/SPARK-25971. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../results/describe-part-after-analyze.sql.out | 12 ++++++------ .../org/apache/spark/sql/SQLQueryTestSuite.scala | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 8ba69c698b551..17dd317f63b70 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -93,7 +93,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1121 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -128,7 +128,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1121 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -155,7 +155,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1098 bytes, 4 rows +Partition Statistics [not included in comparison] bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -190,7 +190,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1121 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -217,7 +217,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1098 bytes, 4 rows +Partition Statistics [not included in comparison] bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -244,7 +244,7 @@ Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1144 bytes, 2 rows +Partition Statistics [not included in comparison] bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 826408c7161e9..6ca3ac596e5f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -272,6 +272,7 @@ class SQLQueryTestSuite extends QueryTest with 
SharedSQLContext { .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg") .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds // If the output is not pre-sorted, sort it. From 0d7396f3af2d4348ae53e6a274df952b7f17c37c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 8 Nov 2018 03:51:55 -0800 Subject: [PATCH 006/145] [SPARK-22827][SQL][FOLLOW-UP] Throw `SparkOutOfMemoryError` in `HashAggregateExec`, too. ## What changes were proposed in this pull request? This is a follow-up pr of #20014 which introduced `SparkOutOfMemoryError` to avoid killing the entire executor when an `OutOfMemoryError` is thrown. We should throw `SparkOutOfMemoryError` in `HashAggregateExec`, too. ## How was this patch tested? Existing tests. Closes #22969 from ueshin/issues/SPARK-22827/oome. Authored-by: Takuya UESHIN Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 25d8e7dff3d99..08dcdf33fb8f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.TaskContext -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -762,6 +762,8 @@ case class HashAggregateExec( ("true", "true", "", "") } + val oomeClassName = classOf[SparkOutOfMemoryError].getName + val findOrInsertRegularHashMap: String = s""" |// generate grouping key @@ -787,7 +789,7 @@ case class HashAggregateExec( | $unsafeRowKeys, ${hashEval.value}); | if ($unsafeRowBuffer == null) { | // failed to allocate the first page - | throw new OutOfMemoryError("No enough memory for aggregation"); + | throw new $oomeClassName("No enough memory for aggregation"); | } |} """.stripMargin From 6abe90625efeb8140531a875700e87ed7e981044 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 8 Nov 2018 23:37:14 +0800 Subject: [PATCH 007/145] [SPARK-25676][SQL][FOLLOWUP] Use 'foreach(_ => ())' ## What changes were proposed in this pull request? #22970 fixed Scala 2.12 build error, and this PR updates the function according to the review comments. ## How was this patch tested? This is also manually tested with Scala 2.12 build. Closes #22978 from dongjoon-hyun/SPARK-25676-3. 
Authored-by: Dongjoon Hyun Signed-off-by: Wenchen Fan --- sql/core/benchmarks/WideTableBenchmark-results.txt | 14 +++++++------- .../execution/benchmark/WideTableBenchmark.scala | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/benchmarks/WideTableBenchmark-results.txt b/sql/core/benchmarks/WideTableBenchmark-results.txt index 7bc388aaa549f..8c09f9ca11307 100644 --- a/sql/core/benchmarks/WideTableBenchmark-results.txt +++ b/sql/core/benchmarks/WideTableBenchmark-results.txt @@ -6,12 +6,12 @@ OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz projection on wide table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -split threshold 10 39634 / 39829 0.0 37798.3 1.0X -split threshold 100 30121 / 30571 0.0 28725.8 1.3X -split threshold 1024 9678 / 9725 0.1 9229.9 4.1X -split threshold 2048 8634 / 8662 0.1 8233.6 4.6X -split threshold 4096 8561 / 8576 0.1 8164.6 4.6X -split threshold 8192 8393 / 8408 0.1 8003.8 4.7X -split threshold 65536 57063 / 57273 0.0 54419.1 0.7X +split threshold 10 40571 / 40937 0.0 38691.7 1.0X +split threshold 100 31116 / 31669 0.0 29674.6 1.3X +split threshold 1024 10077 / 10199 0.1 9609.7 4.0X +split threshold 2048 8654 / 8692 0.1 8253.2 4.7X +split threshold 4096 8006 / 8038 0.1 7634.7 5.1X +split threshold 8192 8069 / 8107 0.1 7695.3 5.0X +split threshold 65536 56973 / 57204 0.0 54333.7 0.7X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala index c61db3ce4b949..52426d81bd1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.Row import org.apache.spark.sql.internal.SQLConf /** @@ -43,7 +42,7 @@ object WideTableBenchmark extends SqlBasedBenchmark { Seq("10", "100", "1024", "2048", "4096", "8192", "65536").foreach { n => benchmark.addCase(s"split threshold $n", numIters = 5) { iter => withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> n) { - df.selectExpr(columns: _*).foreach((x => x): Row => Unit) + df.selectExpr(columns: _*).foreach(_ => ()) } } } From 7bb901aa28d3000c2e18cc769fe5769abd650770 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 8 Nov 2018 10:08:14 -0800 Subject: [PATCH 008/145] [SPARK-25964][SQL][MINOR] Revise OrcReadBenchmark/DataSourceReadBenchmark case names and execution instructions ## What changes were proposed in this pull request? 1. OrcReadBenchmark is under hive module, so the way to run it should be ``` build/sbt "hive/test:runMain " ``` 2. The benchmark "String with Nulls Scan" should be with case "String with Nulls Scan(5%/50%/95%)", not "(0.05%/0.5%/0.95%)" 3. Add the null value percentages in the test case names of DataSourceReadBenchmark, for the benchmark "String with Nulls Scan" . ## How was this patch tested? Re-run benchmarks Closes #22965 from gengliangwang/fixHiveOrcReadBenchmark. 
Lead-authored-by: Gengliang Wang Co-authored-by: Gengliang Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../DataSourceReadBenchmark-results.txt | 336 +++++++++--------- .../benchmark/DataSourceReadBenchmark.scala | 4 +- .../benchmarks/OrcReadBenchmark-results.txt | 170 ++++----- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 11 +- 4 files changed, 263 insertions(+), 258 deletions(-) diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt index 2d3bae442cc50..b07e8b1197ff0 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt @@ -2,268 +2,268 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 21508 / 22112 0.7 1367.5 1.0X -SQL Json 8705 / 8825 1.8 553.4 2.5X -SQL Parquet Vectorized 157 / 186 100.0 10.0 136.7X -SQL Parquet MR 1789 / 1794 8.8 113.8 12.0X -SQL ORC Vectorized 156 / 166 100.9 9.9 138.0X -SQL ORC Vectorized with copy 218 / 225 72.1 13.9 98.6X -SQL ORC MR 1448 / 1492 10.9 92.0 14.9X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 26366 / 26562 0.6 1676.3 1.0X +SQL Json 8709 / 8724 1.8 553.7 3.0X +SQL Parquet Vectorized 166 / 187 94.8 10.5 159.0X +SQL Parquet MR 1706 / 1720 9.2 108.4 15.5X +SQL ORC Vectorized 167 / 174 94.2 10.6 157.9X +SQL ORC Vectorized with copy 226 / 231 69.6 14.4 116.7X +SQL ORC MR 1433 / 1465 11.0 91.1 18.4X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 202 / 211 77.7 12.9 1.0X -ParquetReader Vectorized -> Row 118 / 120 133.5 7.5 1.7X +ParquetReader Vectorized 200 / 207 78.7 12.7 1.0X +ParquetReader Vectorized -> Row 117 / 119 134.7 7.4 1.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 23282 / 23312 0.7 1480.2 1.0X -SQL Json 9187 / 9189 1.7 584.1 2.5X -SQL Parquet Vectorized 204 / 218 77.0 13.0 114.0X -SQL Parquet MR 1941 / 1953 8.1 123.4 12.0X -SQL ORC Vectorized 217 / 225 72.6 13.8 107.5X -SQL ORC Vectorized with copy 279 / 289 56.3 17.8 83.4X -SQL ORC MR 1541 / 1549 10.2 98.0 15.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 26489 / 26547 0.6 1684.1 1.0X +SQL Json 8990 / 8998 1.7 571.5 2.9X +SQL Parquet Vectorized 209 / 221 75.1 13.3 126.5X +SQL Parquet MR 1949 / 1949 8.1 123.9 13.6X +SQL ORC Vectorized 221 / 228 71.3 14.0 120.1X +SQL ORC Vectorized with copy 315 / 319 49.9 20.1 84.0X +SQL ORC MR 1527 / 1549 10.3 97.1 17.3X + +OpenJDK 64-Bit Server 
VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 288 / 297 54.6 18.3 1.0X -ParquetReader Vectorized -> Row 255 / 257 61.7 16.2 1.1X +ParquetReader Vectorized 286 / 296 54.9 18.2 1.0X +ParquetReader Vectorized -> Row 249 / 253 63.1 15.8 1.1X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 24990 / 25012 0.6 1588.8 1.0X -SQL Json 9837 / 9865 1.6 625.4 2.5X -SQL Parquet Vectorized 170 / 180 92.3 10.8 146.6X -SQL Parquet MR 2319 / 2328 6.8 147.4 10.8X -SQL ORC Vectorized 293 / 301 53.7 18.6 85.3X -SQL ORC Vectorized with copy 297 / 309 52.9 18.9 84.0X -SQL ORC MR 1667 / 1674 9.4 106.0 15.0X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 27701 / 27744 0.6 1761.2 1.0X +SQL Json 9703 / 9733 1.6 616.9 2.9X +SQL Parquet Vectorized 176 / 182 89.2 11.2 157.0X +SQL Parquet MR 2164 / 2173 7.3 137.6 12.8X +SQL ORC Vectorized 307 / 314 51.2 19.5 90.2X +SQL ORC Vectorized with copy 312 / 319 50.4 19.8 88.7X +SQL ORC MR 1690 / 1700 9.3 107.4 16.4X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 257 / 274 61.3 16.3 1.0X -ParquetReader Vectorized -> Row 259 / 264 60.8 16.4 1.0X +ParquetReader Vectorized 259 / 277 60.7 16.5 1.0X +ParquetReader Vectorized -> Row 261 / 265 60.3 16.6 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 32537 / 32554 0.5 2068.7 1.0X -SQL Json 12610 / 12668 1.2 801.7 2.6X -SQL Parquet Vectorized 258 / 276 61.0 16.4 126.2X -SQL Parquet MR 2422 / 2435 6.5 154.0 13.4X -SQL ORC Vectorized 378 / 385 41.6 24.0 86.2X -SQL ORC Vectorized with copy 381 / 389 41.3 24.2 85.4X -SQL ORC MR 1797 / 1819 8.8 114.3 18.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 34813 / 34900 0.5 2213.3 1.0X +SQL Json 12570 / 12617 1.3 799.2 2.8X +SQL Parquet Vectorized 270 / 308 58.2 17.2 128.9X +SQL Parquet MR 2427 / 2431 6.5 154.3 14.3X +SQL ORC Vectorized 388 / 398 40.6 24.6 89.8X +SQL ORC Vectorized with copy 395 / 402 39.9 25.1 88.2X +SQL ORC MR 1819 / 1851 8.6 115.7 19.1X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 352 / 368 44.7 22.4 1.0X -ParquetReader Vectorized 
-> Row 351 / 359 44.8 22.3 1.0X +ParquetReader Vectorized 372 / 379 42.3 23.7 1.0X +ParquetReader Vectorized -> Row 357 / 368 44.1 22.7 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 27179 / 27184 0.6 1728.0 1.0X -SQL Json 12578 / 12585 1.3 799.7 2.2X -SQL Parquet Vectorized 161 / 171 97.5 10.3 168.5X -SQL Parquet MR 2361 / 2395 6.7 150.1 11.5X -SQL ORC Vectorized 473 / 480 33.3 30.0 57.5X -SQL ORC Vectorized with copy 478 / 483 32.9 30.4 56.8X -SQL ORC MR 1858 / 1859 8.5 118.2 14.6X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 28753 / 28781 0.5 1828.0 1.0X +SQL Json 12039 / 12215 1.3 765.4 2.4X +SQL Parquet Vectorized 170 / 177 92.4 10.8 169.0X +SQL Parquet MR 2184 / 2196 7.2 138.9 13.2X +SQL ORC Vectorized 432 / 440 36.4 27.5 66.5X +SQL ORC Vectorized with copy 439 / 442 35.9 27.9 65.6X +SQL ORC MR 1812 / 1833 8.7 115.2 15.9X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 251 / 255 62.7 15.9 1.0X -ParquetReader Vectorized -> Row 255 / 259 61.8 16.2 1.0X +ParquetReader Vectorized 253 / 260 62.2 16.1 1.0X +ParquetReader Vectorized -> Row 256 / 257 61.6 16.2 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 34797 / 34830 0.5 2212.3 1.0X -SQL Json 17806 / 17828 0.9 1132.1 2.0X -SQL Parquet Vectorized 260 / 269 60.6 16.5 134.0X -SQL Parquet MR 2512 / 2534 6.3 159.7 13.9X -SQL ORC Vectorized 582 / 593 27.0 37.0 59.8X -SQL ORC Vectorized with copy 576 / 584 27.3 36.6 60.4X -SQL ORC MR 2309 / 2313 6.8 146.8 15.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 36177 / 36188 0.4 2300.1 1.0X +SQL Json 18895 / 18898 0.8 1201.3 1.9X +SQL Parquet Vectorized 267 / 276 58.9 17.0 135.6X +SQL Parquet MR 2355 / 2363 6.7 149.7 15.4X +SQL ORC Vectorized 543 / 546 29.0 34.5 66.6X +SQL ORC Vectorized with copy 548 / 557 28.7 34.8 66.0X +SQL ORC MR 2246 / 2258 7.0 142.8 16.1X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 350 / 363 44.9 22.3 1.0X -ParquetReader Vectorized -> Row 350 / 366 44.9 22.3 1.0X +ParquetReader Vectorized 353 / 367 44.6 22.4 1.0X +ParquetReader Vectorized -> Row 351 / 357 44.7 22.3 1.0X ================================================================================================ Int and String Scan 
================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 22486 / 22590 0.5 2144.5 1.0X -SQL Json 14124 / 14195 0.7 1347.0 1.6X -SQL Parquet Vectorized 2342 / 2347 4.5 223.4 9.6X -SQL Parquet MR 4660 / 4664 2.2 444.4 4.8X -SQL ORC Vectorized 2378 / 2379 4.4 226.8 9.5X -SQL ORC Vectorized with copy 2548 / 2571 4.1 243.0 8.8X -SQL ORC MR 4206 / 4211 2.5 401.1 5.3X +SQL CSV 21130 / 21246 0.5 2015.1 1.0X +SQL Json 12145 / 12174 0.9 1158.2 1.7X +SQL Parquet Vectorized 2363 / 2377 4.4 225.3 8.9X +SQL Parquet MR 4555 / 4557 2.3 434.4 4.6X +SQL ORC Vectorized 2361 / 2388 4.4 225.1 9.0X +SQL ORC Vectorized with copy 2540 / 2557 4.1 242.2 8.3X +SQL ORC MR 4186 / 4209 2.5 399.2 5.0X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 12150 / 12178 0.9 1158.7 1.0X -SQL Json 7012 / 7014 1.5 668.7 1.7X -SQL Parquet Vectorized 792 / 796 13.2 75.5 15.3X -SQL Parquet MR 1961 / 1975 5.3 187.0 6.2X -SQL ORC Vectorized 482 / 485 21.8 46.0 25.2X -SQL ORC Vectorized with copy 710 / 715 14.8 67.7 17.1X -SQL ORC MR 2081 / 2083 5.0 198.5 5.8X +SQL CSV 11693 / 11729 0.9 1115.1 1.0X +SQL Json 7025 / 7025 1.5 669.9 1.7X +SQL Parquet Vectorized 803 / 821 13.1 76.6 14.6X +SQL Parquet MR 1776 / 1790 5.9 169.4 6.6X +SQL ORC Vectorized 491 / 494 21.4 46.8 23.8X +SQL ORC Vectorized with copy 723 / 725 14.5 68.9 16.2X +SQL ORC MR 2050 / 2063 5.1 195.5 5.7X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Data column - CSV 31789 / 31791 0.5 2021.1 1.0X -Data column - Json 12873 / 12918 1.2 818.4 2.5X -Data column - Parquet Vectorized 267 / 280 58.9 17.0 119.1X -Data column - Parquet MR 3387 / 3402 4.6 215.3 9.4X -Data column - ORC Vectorized 391 / 453 40.2 24.9 81.2X -Data column - ORC Vectorized with copy 392 / 398 40.2 24.9 81.2X -Data column - ORC MR 2508 / 2512 6.3 159.4 12.7X -Partition column - CSV 6965 / 6977 2.3 442.8 4.6X -Partition column - Json 5563 / 5576 2.8 353.7 5.7X -Partition column - Parquet Vectorized 65 / 78 241.1 4.1 487.2X -Partition column - Parquet MR 1811 / 1811 8.7 115.1 17.6X -Partition column - ORC Vectorized 66 / 73 239.0 4.2 483.0X -Partition column - ORC 
Vectorized with copy 65 / 70 241.1 4.1 487.3X -Partition column - ORC MR 1775 / 1778 8.9 112.8 17.9X -Both columns - CSV 30032 / 30113 0.5 1909.4 1.1X -Both columns - Json 13941 / 13959 1.1 886.3 2.3X -Both columns - Parquet Vectorized 312 / 330 50.3 19.9 101.7X -Both columns - Parquet MR 3858 / 3862 4.1 245.3 8.2X -Both columns - ORC Vectorized 431 / 437 36.5 27.4 73.8X -Both column - ORC Vectorized with copy 523 / 529 30.1 33.3 60.7X -Both columns - ORC MR 2712 / 2805 5.8 172.4 11.7X +Data column - CSV 30965 / 31041 0.5 1968.7 1.0X +Data column - Json 12876 / 12882 1.2 818.6 2.4X +Data column - Parquet Vectorized 277 / 282 56.7 17.6 111.6X +Data column - Parquet MR 3398 / 3402 4.6 216.0 9.1X +Data column - ORC Vectorized 399 / 407 39.4 25.4 77.5X +Data column - ORC Vectorized with copy 407 / 447 38.6 25.9 76.0X +Data column - ORC MR 2583 / 2589 6.1 164.2 12.0X +Partition column - CSV 7403 / 7427 2.1 470.7 4.2X +Partition column - Json 5587 / 5625 2.8 355.2 5.5X +Partition column - Parquet Vectorized 71 / 78 222.6 4.5 438.3X +Partition column - Parquet MR 1798 / 1808 8.7 114.3 17.2X +Partition column - ORC Vectorized 72 / 75 219.0 4.6 431.2X +Partition column - ORC Vectorized with copy 71 / 77 221.1 4.5 435.4X +Partition column - ORC MR 1772 / 1778 8.9 112.6 17.5X +Both columns - CSV 30211 / 30212 0.5 1920.7 1.0X +Both columns - Json 13382 / 13391 1.2 850.8 2.3X +Both columns - Parquet Vectorized 321 / 333 49.0 20.4 96.4X +Both columns - Parquet MR 3656 / 3661 4.3 232.4 8.5X +Both columns - ORC Vectorized 443 / 448 35.5 28.2 69.9X +Both column - ORC Vectorized with copy 527 / 533 29.9 33.5 58.8X +Both columns - ORC MR 2626 / 2633 6.0 167.0 11.8X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13525 / 13823 0.8 1289.9 1.0X -SQL Json 9913 / 9921 1.1 945.3 1.4X -SQL Parquet Vectorized 1517 / 1517 6.9 144.7 8.9X -SQL Parquet MR 3996 / 4008 2.6 381.1 3.4X -ParquetReader Vectorized 1120 / 1128 9.4 106.8 12.1X -SQL ORC Vectorized 1203 / 1224 8.7 114.7 11.2X -SQL ORC Vectorized with copy 1639 / 1646 6.4 156.3 8.3X -SQL ORC MR 3720 / 3780 2.8 354.7 3.6X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 13918 / 13979 0.8 1327.3 1.0X +SQL Json 10068 / 10068 1.0 960.1 1.4X +SQL Parquet Vectorized 1563 / 1564 6.7 149.0 8.9X +SQL Parquet MR 3835 / 3836 2.7 365.8 3.6X +ParquetReader Vectorized 1115 / 1118 9.4 106.4 12.5X +SQL ORC Vectorized 1172 / 1208 8.9 111.8 11.9X +SQL ORC Vectorized with copy 1630 / 1644 6.4 155.5 8.5X +SQL ORC MR 3708 / 3711 2.8 353.6 3.8X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 15860 / 15877 0.7 
1512.5 1.0X -SQL Json 7676 / 7688 1.4 732.0 2.1X -SQL Parquet Vectorized 1072 / 1084 9.8 102.2 14.8X -SQL Parquet MR 2890 / 2897 3.6 275.6 5.5X -ParquetReader Vectorized 1052 / 1053 10.0 100.4 15.1X -SQL ORC Vectorized 1248 / 1248 8.4 119.0 12.7X -SQL ORC Vectorized with copy 1627 / 1637 6.4 155.2 9.7X -SQL ORC MR 3365 / 3369 3.1 320.9 4.7X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 13972 / 14043 0.8 1332.5 1.0X +SQL Json 7436 / 7469 1.4 709.1 1.9X +SQL Parquet Vectorized 1103 / 1112 9.5 105.2 12.7X +SQL Parquet MR 2841 / 2847 3.7 271.0 4.9X +ParquetReader Vectorized 992 / 1012 10.6 94.6 14.1X +SQL ORC Vectorized 1275 / 1349 8.2 121.6 11.0X +SQL ORC Vectorized with copy 1631 / 1644 6.4 155.5 8.6X +SQL ORC MR 3244 / 3259 3.2 309.3 4.3X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13401 / 13561 0.8 1278.1 1.0X -SQL Json 5253 / 5303 2.0 500.9 2.6X -SQL Parquet Vectorized 233 / 242 45.0 22.2 57.6X -SQL Parquet MR 1791 / 1796 5.9 170.8 7.5X -ParquetReader Vectorized 236 / 238 44.4 22.5 56.7X -SQL ORC Vectorized 453 / 473 23.2 43.2 29.6X -SQL ORC Vectorized with copy 573 / 577 18.3 54.7 23.4X -SQL ORC MR 1846 / 1850 5.7 176.0 7.3X +SQL CSV 11228 / 11244 0.9 1070.8 1.0X +SQL Json 5200 / 5247 2.0 495.9 2.2X +SQL Parquet Vectorized 238 / 242 44.1 22.7 47.2X +SQL Parquet MR 1730 / 1734 6.1 165.0 6.5X +ParquetReader Vectorized 237 / 238 44.3 22.6 47.4X +SQL ORC Vectorized 459 / 462 22.8 43.8 24.4X +SQL ORC Vectorized with copy 581 / 583 18.1 55.4 19.3X +SQL ORC MR 1767 / 1783 5.9 168.5 6.4X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 3147 / 3148 0.3 3001.1 1.0X -SQL Json 2666 / 2693 0.4 2542.9 1.2X -SQL Parquet Vectorized 54 / 58 19.5 51.3 58.5X -SQL Parquet MR 220 / 353 4.8 209.9 14.3X -SQL ORC Vectorized 63 / 77 16.8 59.7 50.3X -SQL ORC Vectorized with copy 63 / 66 16.7 59.8 50.2X -SQL ORC MR 317 / 321 3.3 302.2 9.9X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 3322 / 3356 0.3 3167.9 1.0X +SQL Json 2808 / 2843 0.4 2678.2 1.2X +SQL Parquet Vectorized 56 / 63 18.9 52.9 59.8X +SQL Parquet MR 215 / 219 4.9 205.4 15.4X +SQL ORC Vectorized 64 / 76 16.4 60.9 52.0X +SQL ORC Vectorized with copy 64 / 67 16.3 61.3 51.7X +SQL ORC MR 314 / 316 3.3 299.6 10.6X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 7902 / 7921 0.1 7536.2 1.0X -SQL Json 9467 / 9491 0.1 9028.6 0.8X -SQL 
Parquet Vectorized 73 / 79 14.3 69.8 108.0X -SQL Parquet MR 239 / 247 4.4 228.0 33.1X -SQL ORC Vectorized 78 / 84 13.4 74.6 101.0X -SQL ORC Vectorized with copy 78 / 88 13.4 74.4 101.3X -SQL ORC MR 910 / 918 1.2 867.6 8.7X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 7978 / 7989 0.1 7608.5 1.0X +SQL Json 10294 / 10325 0.1 9816.9 0.8X +SQL Parquet Vectorized 72 / 85 14.5 69.0 110.3X +SQL Parquet MR 237 / 241 4.4 226.4 33.6X +SQL ORC Vectorized 82 / 92 12.7 78.5 97.0X +SQL ORC Vectorized with copy 82 / 88 12.7 78.5 97.0X +SQL ORC MR 900 / 909 1.2 858.5 8.9X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13539 / 13543 0.1 12912.0 1.0X -SQL Json 17420 / 17446 0.1 16613.1 0.8X -SQL Parquet Vectorized 103 / 120 10.2 98.1 131.6X -SQL Parquet MR 250 / 258 4.2 238.9 54.1X -SQL ORC Vectorized 99 / 104 10.6 94.6 136.5X -SQL ORC Vectorized with copy 100 / 106 10.5 95.6 135.1X -SQL ORC MR 1653 / 1659 0.6 1576.3 8.2X +SQL CSV 13489 / 13508 0.1 12864.3 1.0X +SQL Json 18813 / 18827 0.1 17941.4 0.7X +SQL Parquet Vectorized 107 / 111 9.8 101.8 126.3X +SQL Parquet MR 275 / 286 3.8 262.3 49.0X +SQL ORC Vectorized 107 / 115 9.8 101.7 126.4X +SQL ORC Vectorized with copy 107 / 115 9.8 102.3 125.8X +SQL ORC MR 1659 / 1664 0.6 1582.3 8.1X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index a1f51f8e54805..ecd9ead0ae39a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -447,7 +447,9 @@ object DataSourceReadBenchmark extends BenchmarkBase with SQLHelper { } def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - val benchmark = new Benchmark("String with Nulls Scan", values, output = output) + val percentageOfNulls = fractionOfNulls * 100 + val benchmark = + new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { diff --git a/sql/hive/benchmarks/OrcReadBenchmark-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-results.txt index c77f966723d71..80c2f5e93405a 100644 --- a/sql/hive/benchmarks/OrcReadBenchmark-results.txt +++ b/sql/hive/benchmarks/OrcReadBenchmark-results.txt @@ -2,172 +2,172 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1630 / 1639 9.7 103.6 1.0X -Native ORC Vectorized 253 / 288 62.2 16.1 6.4X -Native ORC Vectorized with copy 227 / 244 69.2 14.5 7.2X -Hive built-in ORC 1980 / 1991 7.9 125.9 0.8X +Native ORC MR 1725 / 1759 9.1 109.7 1.0X +Native ORC Vectorized 272 / 316 57.8 
17.3 6.3X +Native ORC Vectorized with copy 239 / 254 65.7 15.2 7.2X +Hive built-in ORC 1970 / 1987 8.0 125.3 0.9X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1587 / 1589 9.9 100.9 1.0X -Native ORC Vectorized 227 / 242 69.2 14.5 7.0X -Native ORC Vectorized with copy 228 / 238 69.0 14.5 7.0X -Hive built-in ORC 2323 / 2332 6.8 147.7 0.7X +Native ORC MR 1633 / 1672 9.6 103.8 1.0X +Native ORC Vectorized 238 / 255 66.0 15.1 6.9X +Native ORC Vectorized with copy 235 / 253 66.8 15.0 6.9X +Hive built-in ORC 2293 / 2305 6.9 145.8 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1726 / 1771 9.1 109.7 1.0X -Native ORC Vectorized 309 / 333 50.9 19.7 5.6X -Native ORC Vectorized with copy 313 / 321 50.2 19.9 5.5X -Hive built-in ORC 2668 / 2672 5.9 169.6 0.6X +Native ORC MR 1677 / 1699 9.4 106.6 1.0X +Native ORC Vectorized 325 / 342 48.3 20.7 5.2X +Native ORC Vectorized with copy 328 / 341 47.9 20.9 5.1X +Hive built-in ORC 2561 / 2569 6.1 162.8 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1722 / 1747 9.1 109.5 1.0X -Native ORC Vectorized 395 / 403 39.8 25.1 4.4X -Native ORC Vectorized with copy 399 / 405 39.4 25.4 4.3X -Hive built-in ORC 2767 / 2777 5.7 175.9 0.6X +Native ORC MR 1791 / 1795 8.8 113.9 1.0X +Native ORC Vectorized 400 / 408 39.3 25.4 4.5X +Native ORC Vectorized with copy 410 / 417 38.4 26.1 4.4X +Hive built-in ORC 2713 / 2720 5.8 172.5 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1797 / 1824 8.8 114.2 1.0X -Native ORC Vectorized 434 / 441 36.2 27.6 4.1X -Native ORC Vectorized with copy 437 / 447 36.0 27.8 4.1X -Hive built-in ORC 2701 / 2710 5.8 171.7 0.7X +Native ORC MR 1791 / 1805 8.8 113.8 1.0X +Native ORC Vectorized 433 / 438 36.3 27.5 4.1X +Native ORC Vectorized with copy 441 / 447 35.7 28.0 4.1X +Hive built-in ORC 2690 / 2803 5.8 171.0 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1931 / 2028 8.1 122.8 1.0X 
-Native ORC Vectorized 542 / 557 29.0 34.5 3.6X -Native ORC Vectorized with copy 550 / 564 28.6 35.0 3.5X -Hive built-in ORC 2816 / 3206 5.6 179.1 0.7X +Native ORC MR 1911 / 1930 8.2 121.5 1.0X +Native ORC Vectorized 543 / 552 29.0 34.5 3.5X +Native ORC Vectorized with copy 547 / 555 28.8 34.8 3.5X +Hive built-in ORC 2967 / 3065 5.3 188.6 0.6X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 4012 / 4068 2.6 382.6 1.0X -Native ORC Vectorized 2337 / 2339 4.5 222.9 1.7X -Native ORC Vectorized with copy 2520 / 2540 4.2 240.3 1.6X -Hive built-in ORC 5503 / 5575 1.9 524.8 0.7X +Native ORC MR 4160 / 4188 2.5 396.7 1.0X +Native ORC Vectorized 2405 / 2406 4.4 229.4 1.7X +Native ORC Vectorized with copy 2588 / 2592 4.1 246.8 1.6X +Hive built-in ORC 5514 / 5562 1.9 525.9 0.8X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Data column - Native ORC MR 2020 / 2025 7.8 128.4 1.0X -Data column - Native ORC Vectorized 398 / 409 39.5 25.3 5.1X -Data column - Native ORC Vectorized with copy 406 / 411 38.8 25.8 5.0X -Data column - Hive built-in ORC 2967 / 2969 5.3 188.6 0.7X -Partition column - Native ORC MR 1494 / 1505 10.5 95.0 1.4X -Partition column - Native ORC Vectorized 73 / 82 216.3 4.6 27.8X -Partition column - Native ORC Vectorized with copy 71 / 80 221.4 4.5 28.4X -Partition column - Hive built-in ORC 1932 / 1937 8.1 122.8 1.0X -Both columns - Native ORC MR 2057 / 2071 7.6 130.8 1.0X -Both columns - Native ORC Vectorized 445 / 448 35.4 28.3 4.5X -Both column - Native ORC Vectorized with copy 534 / 539 29.4 34.0 3.8X -Both columns - Hive built-in ORC 2994 / 2994 5.3 190.3 0.7X +Data column - Native ORC MR 1863 / 1867 8.4 118.4 1.0X +Data column - Native ORC Vectorized 411 / 418 38.2 26.2 4.5X +Data column - Native ORC Vectorized with copy 417 / 422 37.8 26.5 4.5X +Data column - Hive built-in ORC 3297 / 3308 4.8 209.6 0.6X +Partition column - Native ORC MR 1505 / 1506 10.4 95.7 1.2X +Partition column - Native ORC Vectorized 80 / 93 195.6 5.1 23.2X +Partition column - Native ORC Vectorized with copy 78 / 86 201.4 5.0 23.9X +Partition column - Hive built-in ORC 1960 / 1979 8.0 124.6 1.0X +Both columns - Native ORC MR 2076 / 2090 7.6 132.0 0.9X +Both columns - Native ORC Vectorized 450 / 463 34.9 28.6 4.1X +Both column - Native ORC Vectorized with copy 532 / 538 29.6 33.8 3.5X +Both columns - Hive built-in ORC 3528 / 3548 4.5 224.3 0.5X ================================================================================================ Repeated String Scan 
================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1771 / 1785 5.9 168.9 1.0X -Native ORC Vectorized 372 / 375 28.2 35.5 4.8X -Native ORC Vectorized with copy 543 / 576 19.3 51.8 3.3X -Hive built-in ORC 2671 / 2671 3.9 254.7 0.7X +Native ORC MR 1727 / 1733 6.1 164.7 1.0X +Native ORC Vectorized 375 / 379 28.0 35.7 4.6X +Native ORC Vectorized with copy 552 / 556 19.0 52.6 3.1X +Hive built-in ORC 2665 / 2666 3.9 254.2 0.6X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3276 / 3302 3.2 312.5 1.0X -Native ORC Vectorized 1057 / 1080 9.9 100.8 3.1X -Native ORC Vectorized with copy 1420 / 1431 7.4 135.4 2.3X -Hive built-in ORC 5377 / 5407 2.0 512.8 0.6X +Native ORC MR 3324 / 3325 3.2 317.0 1.0X +Native ORC Vectorized 1085 / 1106 9.7 103.4 3.1X +Native ORC Vectorized with copy 1463 / 1471 7.2 139.5 2.3X +Hive built-in ORC 5272 / 5299 2.0 502.8 0.6X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3147 / 3147 3.3 300.1 1.0X -Native ORC Vectorized 1305 / 1319 8.0 124.4 2.4X -Native ORC Vectorized with copy 1685 / 1686 6.2 160.7 1.9X -Hive built-in ORC 4077 / 4085 2.6 388.8 0.8X +Native ORC MR 3045 / 3046 3.4 290.4 1.0X +Native ORC Vectorized 1248 / 1260 8.4 119.0 2.4X +Native ORC Vectorized with copy 1609 / 1624 6.5 153.5 1.9X +Hive built-in ORC 3989 / 3999 2.6 380.4 0.8X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1739 / 1744 6.0 165.8 1.0X -Native ORC Vectorized 500 / 501 21.0 47.7 3.5X -Native ORC Vectorized with copy 618 / 631 17.0 58.9 2.8X -Hive built-in ORC 2411 / 2427 4.3 229.9 0.7X +Native ORC MR 1692 / 1694 6.2 161.3 1.0X +Native ORC Vectorized 471 / 493 22.3 44.9 3.6X +Native ORC Vectorized with copy 588 / 590 17.8 56.1 2.9X +Hive built-in ORC 2398 / 2411 4.4 228.7 0.7X 
================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1348 / 1366 0.8 1285.3 1.0X -Native ORC Vectorized 119 / 134 8.8 113.5 11.3X -Native ORC Vectorized with copy 119 / 148 8.8 113.9 11.3X -Hive built-in ORC 487 / 507 2.2 464.8 2.8X +Native ORC MR 1371 / 1379 0.8 1307.5 1.0X +Native ORC Vectorized 121 / 135 8.6 115.8 11.3X +Native ORC Vectorized with copy 122 / 138 8.6 116.2 11.3X +Hive built-in ORC 521 / 561 2.0 497.1 2.6X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 2667 / 2837 0.4 2543.6 1.0X -Native ORC Vectorized 203 / 222 5.2 193.4 13.2X -Native ORC Vectorized with copy 217 / 255 4.8 207.0 12.3X -Hive built-in ORC 737 / 741 1.4 702.4 3.6X +Native ORC MR 2711 / 2767 0.4 2585.5 1.0X +Native ORC Vectorized 210 / 232 5.0 200.5 12.9X +Native ORC Vectorized with copy 208 / 219 5.0 198.4 13.0X +Hive built-in ORC 764 / 775 1.4 728.3 3.5X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3954 / 3956 0.3 3770.4 1.0X -Native ORC Vectorized 348 / 360 3.0 331.7 11.4X -Native ORC Vectorized with copy 349 / 359 3.0 333.2 11.3X -Hive built-in ORC 1057 / 1067 1.0 1008.0 3.7X +Native ORC MR 3979 / 3988 0.3 3794.4 1.0X +Native ORC Vectorized 357 / 366 2.9 340.2 11.2X +Native ORC Vectorized with copy 361 / 371 2.9 344.5 11.0X +Hive built-in ORC 1091 / 1095 1.0 1040.5 3.6X diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index ec13288f759a6..eb3cde8472dac 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -32,9 +32,11 @@ import org.apache.spark.sql.types._ * Benchmark to measure ORC read performance. * {{{ * To run this benchmark: - * 1. without sbt: bin/spark-submit --class - * 2. build/sbt "sql/test:runMain " - * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * 1. without sbt: bin/spark-submit --class + * --jars ,,,, + * + * 2. build/sbt "hive/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "hive/test:runMain " * Results will be written to "benchmarks/OrcReadBenchmark-results.txt". 
* }}} * @@ -266,8 +268,9 @@ object OrcReadBenchmark extends BenchmarkBase with SQLHelper { s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + val percentageOfNulls = fractionOfNulls * 100 val benchmark = - new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values, output = output) + new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { From 973f7c01df0788b6f5d21224d96c33f14c5b8c64 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Nov 2018 15:49:36 -0800 Subject: [PATCH 009/145] [MINOR] update HiveExternalCatalogVersionsSuite to test 2.4.0 ## What changes were proposed in this pull request? Since Spark 2.4.0 is released, we should test it in HiveExternalCatalogVersionsSuite ## How was this patch tested? N/A Closes #22984 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index fd4985d131885..f1e842334416c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -206,7 +206,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.1.3", "2.2.2", "2.3.2") + val testingVersions = Seq("2.2.2", "2.3.2", "2.4.0") protected var spark: SparkSession = _ From 79551f558dafed41177b605b0436e9340edf5712 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 9 Nov 2018 09:45:06 +0800 Subject: [PATCH 010/145] [SPARK-25945][SQL] Support locale while parsing date/timestamp from CSV/JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In the PR, I propose to add new option `locale` into CSVOptions/JSONOptions to make parsing date/timestamps in local languages possible. Currently the locale is hard coded to `Locale.US`. ## How was this patch tested? Added two tests for parsing a date from CSV/JSON - `ноя 2018`. Closes #22951 from MaxGekk/locale. 
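A minimal usage sketch of the new option (not part of the patch itself; it assumes an active `SparkSession` named `spark`, and the file path and one-column schema are invented for illustration):

```scala
// Read dates written with Russian month names, e.g. "ноя 2018".
// Sketch only: the path and schema are placeholders.
val df = spark.read
  .option("dateFormat", "MMM yyyy")   // month names are locale-sensitive
  .option("locale", "ru-RU")          // new option; the default remains "en-US"
  .schema("d DATE")
  .csv("/tmp/dates_ru.csv")
```

The same `locale` key can also be passed in the options map of `from_json`/`from_csv`, which is how the new tests in this patch exercise it.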
Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/readwriter.py | 15 +++++++++++---- python/pyspark/sql/streaming.py | 14 ++++++++++---- .../spark/sql/catalyst/csv/CSVOptions.scala | 7 +++++-- .../spark/sql/catalyst/json/JSONOptions.scala | 7 +++++-- .../expressions/CsvExpressionsSuite.scala | 19 ++++++++++++++++++- .../expressions/JsonExpressionsSuite.scala | 19 ++++++++++++++++++- .../apache/spark/sql/DataFrameReader.scala | 4 ++++ .../sql/streaming/DataStreamReader.scala | 4 ++++ .../apache/spark/sql/CsvFunctionsSuite.scala | 17 +++++++++++++++++ .../apache/spark/sql/JsonFunctionsSuite.scala | 17 +++++++++++++++++ 10 files changed, 109 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 690b13072244b..726de4a965418 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - dropFieldIfAllNull=None, encoding=None): + dropFieldIfAllNull=None, encoding=None, locale=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param dropFieldIfAllNull: whether to ignore column of all null values or empty array/struct during schema inference. If None is set, it uses the default value, ``false``. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) + samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding, + locale=locale) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None, emptyValue=None): + samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -446,6 +450,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non If None is set, it uses the default value, ``1.0``. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, empty string. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. 
For instance, ``locale`` is used while + parsing dates and timestamps. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -465,7 +472,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema, emptyValue=emptyValue) + enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index b18453b2a4f96..02b14ea187cba 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -404,7 +404,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -469,6 +469,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, including tab and line feed characters) or not. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -483,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -564,7 +567,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None, emptyValue=None): + enforceSchema=None, emptyValue=None, locale=None): r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -660,6 +663,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non different, ``\0`` otherwise.. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, empty string. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. 
For instance, ``locale`` is used while + parsing dates and timestamps. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -677,7 +683,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, - emptyValue=emptyValue) + emptyValue=emptyValue, locale=locale) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index cdaaa172e8367..642823582a645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -131,13 +131,16 @@ class CSVOptions( val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 64152e04928d2..e10b8a327c01a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -76,16 +76,19 @@ private[sql] class JSONOptions( // Whether to ignore column of all null values or empty array/struct during schema inference val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. 
val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index d006197bd5678..f5aaaec456153 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.Calendar +import java.text.SimpleDateFormat +import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException @@ -209,4 +210,20 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P "2015-12-31T16:00:00" ) } + + test("parse date with locale") { + Seq("en-US", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05") + val schema = new StructType().add("d", DateType) + val dateFormat = "MMM yyyy" + val sdf = new SimpleDateFormat(dateFormat, locale) + val dateStr = sdf.format(date) + val options = Map("dateFormat" -> dateFormat, "locale" -> langTag) + + checkEvaluation( + CsvToStructs(schema, options, Literal.create(dateStr), gmtId), + InternalRow(17836)) // number of days from 1970-01-01 + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 304642161146b..6ee8c74010d3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.Calendar +import java.text.SimpleDateFormat +import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException @@ -737,4 +738,20 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))), "struct") } + + test("parse date with locale") { + Seq("en-US", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05") + val schema = new StructType().add("d", DateType) + val dateFormat = "MMM yyyy" + val sdf = new SimpleDateFormat(dateFormat, locale) + val dateStr = s"""{"d":"${sdf.format(date)}"}""" + val options = Map("dateFormat" -> dateFormat, "locale" -> langTag) + + checkEvaluation( + JsonToStructs(schema, options, Literal.create(dateStr), gmtId), + InternalRow(17836)) // number of days from 1970-01-01 + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
index 95c97e5c9433c..02ffc940184db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -384,6 +384,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * for schema inferring. *
<li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
 * empty array/struct during schema inference.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
 * </ul>
 *
 * @since 2.0.0
@@ -604,6 +606,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
 * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
 * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
 * </ul>
 *
 * @since 2.0.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 4c7dcedafeeae..20c84305776ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -296,6 +296,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
 * that should be used for parsing.
 * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
 * empty array/struct during schema inference.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
 * </ul>
 *
 * @since 2.0.0
@@ -372,6 +374,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
 * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
 * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
 * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index b97ac380def63..1c359ce1d2014 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.SparkException @@ -164,4 +167,18 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil) } + + test("parse timestamps with locale") { + Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00") + val timestampFormat = "dd MMM yyyy HH:mm" + val sdf = new SimpleDateFormat(timestampFormat, locale) + val input = Seq(s"""${sdf.format(ts)}""").toDS() + val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag) + val df = input.select(from_csv($"value", lit("time timestamp"), options.asJava)) + + checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0")))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index d6b73387e84b3..24e7564259c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat +import java.util.Locale + import collection.JavaConverters._ import org.apache.spark.SparkException @@ -591,4 +594,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))), Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil) } + + test("parse timestamps with locale") { + Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00") + val timestampFormat = "dd MMM yyyy HH:mm" + val sdf = new SimpleDateFormat(timestampFormat, locale) + val input = Seq(s"""{"time": "${sdf.format(ts)}"}""").toDS() + val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag) + val df = input.select(from_json($"value", "time timestamp", options)) + + checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0")))) + } + } } From 0558d021cc0aeae37ef0e043d244fd0300a57cd5 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 9 Nov 2018 11:45:03 +0800 Subject: [PATCH 011/145] [SPARK-25510][SQL][TEST][FOLLOW-UP] Remove BenchmarkWithCodegen ## What changes were proposed in this pull request? Remove `BenchmarkWithCodegen` as we don't use it anymore. More details: https://github.com/apache/spark/pull/22484#discussion_r221397904 ## How was this patch tested? N/A Closes #22985 from wangyum/SPARK-25510. 
Authored-by: Yuming Wang Signed-off-by: hyukjinkwon --- .../benchmark/BenchmarkWithCodegen.scala | 54 ------------------- 1 file changed, 54 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala deleted file mode 100644 index 51331500479a3..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.benchmark - -import org.apache.spark.SparkFunSuite -import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.SparkSession - -/** - * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together - * with other test suites). - */ -private[benchmark] trait BenchmarkWithCodegen extends SparkFunSuite { - - lazy val sparkSession = SparkSession.builder - .master("local[1]") - .appName("microbenchmark") - .config("spark.sql.shuffle.partitions", 1) - .config("spark.sql.autoBroadcastJoinThreshold", 1) - .getOrCreate() - - /** Runs function `f` with whole stage codegen on and off. */ - def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { - val benchmark = new Benchmark(name, cardinality) - - benchmark.addCase(s"$name wholestage off", numIters = 2) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) - f - } - - benchmark.addCase(s"$name wholestage on", numIters = 5) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - f - } - - benchmark.run() - } - -} From 297b81e0eb1493b12838c3c48c6f754289ce1c1f Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Fri, 9 Nov 2018 07:55:02 -0600 Subject: [PATCH 012/145] [SPARK-20156][SQL][ML][FOLLOW-UP] Java String toLowerCase with Locale.ROOT ## What changes were proposed in this pull request? Add `Locale.ROOT` to all internal calls to String `toLowerCase`, `toUpperCase` ## How was this patch tested? existing tests Closes #22975 from zhengruifeng/Tokenizer_Locale. 
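The motivation for pinning the locale is easiest to see with a Turkish default locale, where the uppercase `I` lowercases to the dotless `ı`. A small illustrative snippet (not taken from the patch):

```scala
import java.util.Locale

// Under a Turkish locale the 'I' in "MANIFEST" becomes the dotless 'ı',
// so a check such as endsWith("manifest.mf") would silently stop matching.
val tr = Locale.forLanguageTag("tr-TR")
"MANIFEST.MF".toLowerCase(tr)           // "manıfest.mf"
"MANIFEST.MF".toLowerCase(Locale.ROOT)  // "manifest.mf" -- locale-independent
```

In this patch the calls being pinned are the merge-strategy filename checks in `SparkBuild.scala` (diff below), so the build behaves the same regardless of the JVM's default locale.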
Authored-by: zhengruifeng Signed-off-by: Sean Owen --- project/SparkBuild.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ca57df0e31a7f..5e034f9fe2a95 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -17,6 +17,7 @@ import java.io._ import java.nio.file.Files +import java.util.Locale import scala.io.Source import scala.util.Properties @@ -650,10 +651,13 @@ object Assembly { }, jarName in (Test, assembly) := s"${moduleName.value}-test-${version.value}.jar", mergeStrategy in assembly := { - case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard - case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard + case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") + => MergeStrategy.discard + case m if m.toLowerCase(Locale.ROOT).matches("meta-inf.*\\.sf$") + => MergeStrategy.discard case "log4j.properties" => MergeStrategy.discard - case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines + case m if m.toLowerCase(Locale.ROOT).startsWith("meta-inf/services/") + => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } From 25f506e2ad865ed671cfc618ca9d272bfb5712b7 Mon Sep 17 00:00:00 2001 From: William Montaz Date: Fri, 9 Nov 2018 08:02:53 -0600 Subject: [PATCH 013/145] [SPARK-25973][CORE] Spark History Main page performance improvement HistoryPage.scala counts applications (with a predicate depending on if it is displaying incomplete or complete applications) to check if it must display the dataTable. Since it only checks if allAppsSize > 0, we could use exists method on the iterator. This way we stop iterating at the first occurence found. Such a change has been relevant (roughly 12s improvement on page loading) on our cluster that runs tens of thousands of jobs per day. Closes #22982 from Willymontaz/SPARK-25973. Authored-by: William Montaz Signed-off-by: Sean Owen --- .../org/apache/spark/deploy/history/HistoryPage.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 32667ddf5c7ea..00ca4efa4d266 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -31,8 +31,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val requestedIncomplete = Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean - val allAppsSize = parent.getApplicationList() - .count(isApplicationCompleted(_) != requestedIncomplete) + val displayApplications = parent.getApplicationList() + .exists(isApplicationCompleted(_) != requestedIncomplete) val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() @@ -63,9 +63,9 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } { - if (allAppsSize > 0) { + if (displayApplications) { ++ + request, "/static/dataTables.rowsGroup.js")}> ++
    ++ ++ From 657fd00b5204859c2e6d7c19a71a3ec5ecf7c869 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 9 Nov 2018 08:22:26 -0800 Subject: [PATCH 014/145] [SPARK-25988][SQL] Keep names unchanged when deduplicating the column names in Analyzer ## What changes were proposed in this pull request? When the queries do not use the column names with the same case, users might hit various errors. Below is a typical test failure they can hit. ``` Expected only partition pruning predicates: ArrayBuffer(isnotnull(tdate#237), (cast(tdate#237 as string) >= 2017-08-15)); org.apache.spark.sql.AnalysisException: Expected only partition pruning predicates: ArrayBuffer(isnotnull(tdate#237), (cast(tdate#237 as string) >= 2017-08-15)); at org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.prunePartitionsByFilter(ExternalCatalogUtils.scala:146) at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.listPartitionsByFilter(InMemoryCatalog.scala:560) at org.apache.spark.sql.catalyst.catalog.SessionCatalog.listPartitionsByFilter(SessionCatalog.scala:925) ``` ## How was this patch tested? Added two test cases. Closes #22990 from gatorsmile/fix1283. Authored-by: gatorsmile Signed-off-by: gatorsmile --- .../sql/catalyst/analysis/Analyzer.scala | 3 +- .../sql/catalyst/analysis/unresolved.scala | 1 + .../expressions/namedExpressions.scala | 5 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 53 +++++++++++++++++++ 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c2d22c5e7ce60..6dc5b3f28b914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -824,7 +824,8 @@ class Analyzer( } private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + val exprId = attrMap.getOrElse(attr, attr).exprId + attr.withExprId(exprId) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 857cf382b8f2c..36cad3cf74785 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -112,6 +112,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def withMetadata(newMetadata: Metadata): Attribute = this + override def withExprId(newExprId: ExprId): UnresolvedAttribute = this override def toString: String = s"'$name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 584a2946bd564..049ea77691395 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -115,6 +115,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with 
NullIn def withQualifier(newQualifier: Seq[String]): Attribute def withName(newName: String): Attribute def withMetadata(newMetadata: Metadata): Attribute + def withExprId(newExprId: ExprId): Attribute override def toAttribute: Attribute = this def newInstance(): Attribute @@ -299,7 +300,7 @@ case class AttributeReference( } } - def withExprId(newExprId: ExprId): AttributeReference = { + override def withExprId(newExprId: ExprId): AttributeReference = { if (exprId == newExprId) { this } else { @@ -362,6 +363,8 @@ case class PrettyAttribute( throw new UnsupportedOperationException override def qualifier: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + override def withExprId(newExprId: ExprId): Attribute = + throw new UnsupportedOperationException override def nullable: Boolean = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 631ab1b7ece7f..dbb0790a4682c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2856,6 +2856,59 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000"))) } } + + test("SPARK-25988: self join with aliases on partitioned tables #1") { + withTempView("tmpView1", "tmpView2") { + withTable("tab1", "tab2") { + sql( + """ + |CREATE TABLE `tab1` (`col1` INT, `TDATE` DATE) + |USING CSV + |PARTITIONED BY (TDATE) + """.stripMargin) + spark.table("tab1").where("TDATE >= '2017-08-15'").createOrReplaceTempView("tmpView1") + sql("CREATE TABLE `tab2` (`TDATE` DATE) USING parquet") + sql( + """ + |CREATE OR REPLACE TEMPORARY VIEW tmpView2 AS + |SELECT N.tdate, col1 AS aliasCol1 + |FROM tmpView1 N + |JOIN tab2 Z + |ON N.tdate = Z.tdate + """.stripMargin) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + sql("SELECT * FROM tmpView2 x JOIN tmpView2 y ON x.tdate = y.tdate").collect() + } + } + } + } + + test("SPARK-25988: self join with aliases on partitioned tables #2") { + withTempView("tmp") { + withTable("tab1", "tab2") { + sql( + """ + |CREATE TABLE `tab1` (`EX` STRING, `TDATE` DATE) + |USING parquet + |PARTITIONED BY (tdate) + """.stripMargin) + sql("CREATE TABLE `tab2` (`TDATE` DATE) USING parquet") + sql( + """ + |CREATE OR REPLACE TEMPORARY VIEW TMP as + |SELECT N.tdate, EX AS new_ex + |FROM tab1 N + |JOIN tab2 Z + |ON N.tdate = Z.tdate + """.stripMargin) + sql( + """ + |SELECT * FROM TMP x JOIN TMP y + |ON x.tdate = y.tdate + """.stripMargin).queryExecution.executedPlan + } + } + } } case class Foo(bar: Option[String]) From 1db799795cf3c15798fbfb6043ec5775e16ba5ea Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 9 Nov 2018 09:44:04 -0800 Subject: [PATCH 015/145] [SPARK-25979][SQL] Window function: allow parentheses around window reference ## What changes were proposed in this pull request? 
Very minor parser bug, but possibly problematic for code-generated queries: Consider the following two queries: ``` SELECT avg(k) OVER (w) FROM kv WINDOW w AS (PARTITION BY v ORDER BY w) ORDER BY 1 ``` and ``` SELECT avg(k) OVER w FROM kv WINDOW w AS (PARTITION BY v ORDER BY w) ORDER BY 1 ``` The former, with parens around the OVER condition, fails to parse while the latter, without parens, succeeds: ``` Error in SQL statement: ParseException: mismatched input '(' expecting {, ',', 'FROM', 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'LIMIT', 'LATERAL', 'WINDOW', 'UNION', 'EXCEPT', 'MINUS', 'INTERSECT', 'SORT', 'CLUSTER', 'DISTRIBUTE'}(line 1, pos 19) == SQL == SELECT avg(k) OVER (w) FROM kv WINDOW w AS (PARTITION BY v ORDER BY w) ORDER BY 1 -------------------^^^ ``` This was found when running the cockroach DB tests. I tried PostgreSQL, The SQL with parentheses is also workable. ## How was this patch tested? Unit test Closes #22987 from gengliangwang/windowParentheses. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../resources/sql-tests/inputs/window.sql | 6 ++++++ .../sql-tests/results/window.sql.out | 19 ++++++++++++++++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index e2d34d1650ddc..5e732edb17baa 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -691,6 +691,7 @@ namedWindow windowSpec : name=identifier #windowRef + | '('name=identifier')' #windowRef | '(' ( CLUSTER BY partition+=expression (',' partition+=expression)* | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index cda4db4b449fe..faab4c61c8640 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -109,3 +109,9 @@ last_value(false, false) OVER w AS last_value_contain_null FROM testData WINDOW w AS () ORDER BY cate, val; + +-- parentheses around window reference +SELECT cate, sum(val) OVER (w) +FROM testData +WHERE val is not null +WINDOW w AS (PARTITION BY cate ORDER BY val); diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 5071e0bd26b2a..367dc4f513635 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 22 +-- Number of queries: 23 -- !query 0 @@ -363,3 +363,20 @@ NULL a false true false false true false 1 b false true false false true false 2 b false true false false true false 3 b false true false false true false + + +-- !query 22 +SELECT cate, sum(val) OVER (w) +FROM testData +WHERE val is not null +WINDOW w AS (PARTITION BY cate ORDER BY val) +-- !query 22 schema +struct +-- !query 22 output +NULL 3 +a 2 +a 2 +a 4 +b 1 +b 3 +b 6 From 8e5f3c6ba6ef9b92578a6b292cfa1c480370cbfc Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 9 Nov 2018 15:40:15 -0600 Subject: [PATCH 016/145] [SPARK-24101][ML][MLLIB] ML Evaluators should use weight column - added weight column for multiclass classification evaluator ## What changes were proposed in this pull request? The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics and MulticlassMetrics should use sample weight data. I've closed the PR: https://github.com/apache/spark/pull/16557 as recommended in favor of creating three pull requests, one for each of the evaluators (binary/regression/multiclass) to make it easier to review/update. Note: I've updated the JIRA to: https://issues.apache.org/jira/browse/SPARK-24101 Which is a child of JIRA: https://issues.apache.org/jira/browse/SPARK-18693 ## How was this patch tested? I added tests to the metrics class. Closes #17086 from imatiach-msft/ilmat/multiclass-evaluate. 
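For context, a sketch of how the new parameter is used from the ML API (illustrative only; it assumes an active `SparkSession` named `spark`, and the tiny hand-made prediction/label/weight frame stands in for real model output):

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Sketch: a toy (prediction, label, weight) frame in place of real predictions.
val predictions = spark.createDataFrame(Seq(
  (0.0, 0.0, 1.0),
  (1.0, 1.0, 2.5),
  (1.0, 0.0, 0.5)
)).toDF("prediction", "label", "weight")

val evaluator = new MulticlassClassificationEvaluator()
  .setPredictionCol("prediction")
  .setLabelCol("label")
  .setWeightCol("weight")    // new setter added by this patch; unit weights when unset
  .setMetricName("f1")

val weightedF1 = evaluator.evaluate(predictions)
```

With the weight column set, rows contribute to the confusion counts in proportion to their weight rather than as plain counts of 1, as the `MulticlassMetrics` changes below show.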
Authored-by: Ilya Matiach Signed-off-by: Sean Owen --- .../MulticlassClassificationEvaluator.scala | 19 ++- .../mllib/evaluation/MulticlassMetrics.scala | 55 +++--- .../evaluation/MulticlassMetricsSuite.scala | 158 ++++++++++++++---- 3 files changed, 170 insertions(+), 62 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 794b1e7d9d881..f1602c1bc5333 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "f1") @Since("2.0.0") @@ -75,11 +80,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = - dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(prediction: Double, label: Double) => (prediction, label) + val predictionAndLabelsWithWeights = + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + .rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new MulticlassMetrics(predictionAndLabels) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure case "weightedPrecision" => metrics.weightedPrecision diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 980e0c92531a2..ad83c24ede964 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,10 +27,19 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. 
* - * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or + * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) { + val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) + case (prediction: Double, label: Double) => + (prediction, label, 1.0) + case other => + throw new IllegalArgumentException(s"Expected tuples, got $other") + } /** * An auxiliary constructor taking a DataFrame. @@ -39,21 +48,29 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predLabelsWeight.map { + case (_: Double, label: Double, weight: Double) => + (label, weight) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (prediction, if (prediction != label) 1 else 0) + private lazy val fpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val confusions = predictionAndLabels - .map { case (prediction, label) => - ((label, prediction), 1) + private lazy val confusions = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + ((label, prediction), weight) }.reduceByKey(_ + _) .collectAsMap() @@ -71,7 +88,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl while (i < n) { var j = 0 while (j < n) { - values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0) j += 1 } i += 1 @@ -92,8 +109,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl */ @Since("1.1.0") def falsePositiveRate(label: Double): Double = { - val fp = fpByClass.getOrElse(label, 0) - fp.toDouble / (labelCount - labelCountByClass(label)) + val fp = fpByClass.getOrElse(label, 0.0) + fp / (labelCount - labelCountByClass(label)) } /** @@ -103,7 +120,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) - val fp = fpByClass.getOrElse(label, 0) + val fp = fpByClass.getOrElse(label, 0.0) if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) } @@ -112,7 +129,7 @@ class MulticlassMetrics 
@Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param label the label. */ @Since("1.1.0") - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label) /** * Returns f-measure for a given label (category) @@ -140,7 +157,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * out of the total number of instances.) */ @Since("2.0.0") - lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount + lazy val accuracy: Double = tpByClass.values.sum / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 5394baab94bcf..8779de590a256 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -18,10 +18,14 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.Matrices +import org.apache.spark.ml.linalg.Matrices +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val delta = 1e-7 + test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: @@ -35,7 +39,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(predictionAndLabels) - val delta = 0.0000001 val tpRate0 = 2.0 / (2 + 2) val tpRate1 = 3.0 / (3 + 1) val tpRate2 = 1.0 / (1 + 0) @@ -55,41 +58,122 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) - assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) - assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) - assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) - assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) - assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) - assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) - assert(math.abs(metrics.precision(1.0) - precision1) < delta) - assert(math.abs(metrics.precision(2.0) - precision2) < delta) - assert(math.abs(metrics.recall(0.0) - recall0) < delta) - assert(math.abs(metrics.recall(1.0) - recall1) < delta) - assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) - assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix 
relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) + + assert(metrics.accuracy ~== + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = 4.0 / 9 + val weight1 = 4.0 / 9 + val weight2 = 1.0 / 9 + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) + } + + test("Multiclass evaluation metrics with weights") { + /* + * Confusion matrix for 3-class classification with total 9 instances with 2 weights: + * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances) + * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances) + * |0 |0 |1 * w2| true class2 (1 instance) + */ + val w1 = 2.2 + val w2 = 1.5 + val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2 + val confusionMatrix = Matrices.dense(3, 3, + Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabelsWithWeights = sc.parallelize( + Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2), + (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), + (2.0, 0.0, w1)), 2) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) + val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) + val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1)) + val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2)) + val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2)) + val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2) + val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2) + val 
recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val recall2 = (1.0 * w2) / (1.0 * w2 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) - assert(math.abs(metrics.accuracy - - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) - assert(math.abs(metrics.weightedTruePositiveRate - - ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) - assert(math.abs(metrics.weightedFalsePositiveRate - - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) - assert(math.abs(metrics.weightedPrecision - - ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) - assert(math.abs(metrics.weightedRecall - - ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) - assert(math.abs(metrics.weightedFMeasure - - ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) - assert(math.abs(metrics.weightedFMeasure(2.0) - - ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) - assert(metrics.labels.sameElements(labels)) + assert(metrics.accuracy ~== + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw + val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw + val weight2 = 1 * w2 / tw + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * 
recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) } } From d66a4e82eceb89a274edeb22c2fb4384bed5078b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 9 Nov 2018 22:42:48 -0800 Subject: [PATCH 017/145] [SPARK-25102][SQL] Write Spark version to ORC/Parquet file metadata ## What changes were proposed in this pull request? Currently, Spark writes the Spark version number into Hive table properties with `spark.sql.create.version`. ``` parameters:{ spark.sql.sources.schema.part.0={ "type":"struct", "fields":[{"name":"a","type":"integer","nullable":true,"metadata":{}}] }, transient_lastDdlTime=1541142761, spark.sql.sources.schema.numParts=1, spark.sql.create.version=2.4.0 } ``` This PR aims to write the Spark version to ORC/Parquet file metadata with `org.apache.spark.version` because we already use the `org.apache.` prefix in Parquet metadata. It differs from the Hive table property key `spark.sql.create.version`, but it seems we cannot change that property key for backward compatibility. After this PR, ORC and Parquet files generated by Spark will have the following metadata. **ORC (`native` and `hive` implementation)** ``` $ orc-tools meta /tmp/o File Version: 0.12 with ... ... User Metadata: org.apache.spark.version=3.0.0 ``` **PARQUET** ``` $ parquet-tools meta /tmp/p ... creator: parquet-mr version 1.10.0 (build 031a6654009e3b82020012a18434c582bd74c73a) extra: org.apache.spark.version = 3.0.0 extra: org.apache.spark.sql.parquet.row.metadata = {"type":"struct","fields":[{"name":"id","type":"long","nullable":false,"metadata":{}}]} ``` ## How was this patch tested? Pass the Jenkins with newly added test cases. This closes #22255. Closes #22932 from dongjoon-hyun/SPARK-25102. Authored-by: Dongjoon Hyun Signed-off-by: gatorsmile --- .../main/scala/org/apache/spark/package.scala | 3 +++ .../org/apache/spark/util/VersionUtils.scala | 14 ++++++++++ .../apache/spark/util/VersionUtilsSuite.scala | 25 ++++++++++++++++++ .../datasources/orc/OrcOutputWriter.scala | 15 ++++++++--- .../execution/datasources/orc/OrcUtils.scala | 14 +++++++--- .../parquet/ParquetWriteSupport.scala | 7 ++++- .../scala/org/apache/spark/sql/package.scala | 9 +++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 4 +-- .../datasources/HadoopFsRelationSuite.scala | 2 +- .../datasources/orc/OrcSourceSuite.scala | 20 +++++++++++++- .../datasources/parquet/ParquetIOSuite.scala | 21 ++++++++++++++- .../spark/sql/hive/orc/OrcFileFormat.scala | 26 ++++++++++++++++--- 12 files changed, 144 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8058a4d5dbdea..5d0639e92c36a 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -19,6 +19,8 @@ package org.apache import java.util.Properties +import org.apache.spark.util.VersionUtils + /** * Core Spark functionality.
[[org.apache.spark.SparkContext]] serves as the main entry point to * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, @@ -89,6 +91,7 @@ package object spark { } val SPARK_VERSION = SparkBuildInfo.spark_version + val SPARK_VERSION_SHORT = VersionUtils.shortVersion(SparkBuildInfo.spark_version) val SPARK_BRANCH = SparkBuildInfo.spark_branch val SPARK_REVISION = SparkBuildInfo.spark_revision val SPARK_BUILD_USER = SparkBuildInfo.spark_build_user diff --git a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala index 828153b868420..c0f8866dd58dc 100644 --- a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala @@ -23,6 +23,7 @@ package org.apache.spark.util private[spark] object VersionUtils { private val majorMinorRegex = """^(\d+)\.(\d+)(\..*)?$""".r + private val shortVersionRegex = """^(\d+\.\d+\.\d+)(.*)?$""".r /** * Given a Spark version string, return the major version number. @@ -36,6 +37,19 @@ private[spark] object VersionUtils { */ def minorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._2 + /** + * Given a Spark version string, return the short version string. + * E.g., for 3.0.0-SNAPSHOT, return '3.0.0'. + */ + def shortVersion(sparkVersion: String): String = { + shortVersionRegex.findFirstMatchIn(sparkVersion) match { + case Some(m) => m.group(1) + case None => + throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" + + s" version string, but it could not find the major/minor/maintenance version numbers.") + } + } + /** * Given a Spark version string, return the (major version number, minor version number). * E.g., for 2.0.1-SNAPSHOT, return (2, 0). 
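As a quick illustration of the `shortVersion` helper added above (this sketch is not part of the patch), the snippet below reuses the same regex; the object name `ShortVersionDemo` is invented for the example, and the expected results mirror the test cases added in `VersionUtilsSuite` right after this.

```
// Minimal, self-contained sketch of the shortVersion logic introduced above.
// ShortVersionDemo is a made-up name used only for this illustration.
object ShortVersionDemo {
  // Same pattern as the new shortVersionRegex: major.minor.maintenance plus an optional suffix.
  private val shortVersionRegex = """^(\d+\.\d+\.\d+)(.*)?$""".r

  def shortVersion(sparkVersion: String): String =
    shortVersionRegex.findFirstMatchIn(sparkVersion) match {
      case Some(m) => m.group(1)
      case None => throw new IllegalArgumentException(
        s"Could not find major/minor/maintenance version numbers in '$sparkVersion'")
    }

  def main(args: Array[String]): Unit = {
    println(shortVersion("3.0.0"))          // 3.0.0
    println(shortVersion("3.0.0-SNAPSHOT")) // 3.0.0 -- the -SNAPSHOT suffix is dropped
    // shortVersion("3.0") throws IllegalArgumentException: no maintenance version number
  }
}
```

The resulting short string is what `SPARK_VERSION_SHORT` carries and what the ORC/Parquet writers later in this patch store under the `org.apache.spark.version` metadata key.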
diff --git a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala index b36d6be231d39..56623ebea1651 100644 --- a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala @@ -73,4 +73,29 @@ class VersionUtilsSuite extends SparkFunSuite { } } } + + test("Return short version number") { + assert(shortVersion("3.0.0") === "3.0.0") + assert(shortVersion("3.0.0-SNAPSHOT") === "3.0.0") + withClue("shortVersion parsing should fail for missing maintenance version number") { + intercept[IllegalArgumentException] { + shortVersion("3.0") + } + } + withClue("shortVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + shortVersion("x.0.0") + } + } + withClue("shortVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + shortVersion("3.x.0") + } + } + withClue("shortVersion parsing should fail for invalid maintenance version number") { + intercept[IllegalArgumentException] { + shortVersion("3.0.x") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index 84755bfa301f0..7e38fc651a31f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.orc import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.orc.mapred.OrcStruct -import org.apache.orc.mapreduce.OrcOutputFormat +import org.apache.orc.OrcFile +import org.apache.orc.mapred.{OrcOutputFormat => OrcMapRedOutputFormat, OrcStruct} +import org.apache.orc.mapreduce.{OrcMapreduceRecordWriter, OrcOutputFormat} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter @@ -36,11 +37,17 @@ private[orc] class OrcOutputWriter( private[this] val serializer = new OrcSerializer(dataSchema) private val recordWriter = { - new OrcOutputFormat[OrcStruct]() { + val orcOutputFormat = new OrcOutputFormat[OrcStruct]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { new Path(path) } - }.getRecordWriter(context) + } + val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc") + val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration) + val writer = OrcFile.createWriter(filename, options) + val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer) + OrcUtils.addSparkVersionMetadata(writer) + recordWriter } override def write(row: InternalRow): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 95fb25bf5addb..57d2c56e87b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.execution.datasources.orc +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Locale import 
scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription, Writer} -import org.apache.spark.SparkException +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ @@ -144,4 +145,11 @@ object OrcUtils extends Logging { } } } + + /** + * Add a metadata specifying Spark version. + */ + def addSparkVersionMetadata(writer: Writer): Unit = { + writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index b40b8c2e61f33..8814e3c6ccf94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -29,7 +29,9 @@ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.io.api.{Binary, RecordConsumer} +import org.apache.spark.SPARK_VERSION_SHORT import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -93,7 +95,10 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter] val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) - val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava + val metadata = Map( + SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT, + ParquetReadSupport.SPARK_METADATA_KEY -> schemaString + ).asJava logInfo( s"""Initialized Parquet WriteSupport with Catalyst schema: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 161e0102f0b43..354660e9d5943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -44,4 +44,13 @@ package object sql { type Strategy = SparkStrategy type DataFrame = Dataset[Row] + + /** + * Metadata key which is used to write Spark version in the followings: + * - Parquet file metadata + * - ORC file metadata + * + * Note that Hive table property `spark.sql.create.version` also has Spark version. 
+ */ + private[sql] val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index e1567d06e23eb..861aa179a4a81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -506,7 +506,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 800) + assert(inMemoryRelation.computeStats().sizeInBytes === 868) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -519,7 +519,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 800) + assert(inMemoryRelation2.computeStats().sizeInBytes === 868) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index c1f2c18d1417d..6e08ee3c4ba3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -45,7 +45,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { import testImplicits._ Seq(1.0, 0.5).foreach { compressionFactor => withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, - "spark.sql.autoBroadcastJoinThreshold" -> "400") { + "spark.sql.autoBroadcastJoinThreshold" -> "434") { withTempPath { workDir => // the file size is 740 bytes val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index dc81c0585bf18..48910103e702a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 import java.sql.Timestamp import java.util.Locale @@ -30,7 +31,8 @@ import org.apache.orc.OrcProto.Stream.Kind import org.apache.orc.impl.RecordReaderImpl import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.Row +import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -314,6 +316,22 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) } } + + test("Write Spark version into ORC 
file metadata") { + withTempPath { path => + spark.range(1).repartition(1).write.orc(path.getCanonicalPath) + + val partFiles = path.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + val orcFilePath = new Path(partFiles.head.getAbsolutePath) + val readerOptions = OrcFile.readerOptions(new Configuration()) + val reader = OrcFile.createReader(orcFilePath, readerOptions) + val version = UTF_8.decode(reader.getMetadataValue(SPARK_VERSION_METADATA_KEY)).toString + assert(version === SPARK_VERSION_SHORT) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 002c42f23bd64..6b05b9c0f7207 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.parquet.HadoopReadOptions import org.apache.parquet.column.{Encoding, ParquetProperties} import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.example.data.simple.SimpleGroup @@ -34,10 +35,11 @@ import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.HadoopInputFile import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.apache.spark.SparkException +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} @@ -799,6 +801,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) } } + + test("Write Spark version into Parquet metadata") { + withTempPath { dir => + val path = dir.getAbsolutePath + spark.range(1).repartition(1).write.parquet(path) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + + val conf = new Configuration() + val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), conf) + val parquetReadOptions = HadoopReadOptions.builder(conf).build() + val m = ParquetFileReader.open(hadoopInputFile, parquetReadOptions) + val metaData = m.getFileMetaData.getKeyValueMetaData + m.close() + + assert(metaData.get(SPARK_VERSION_METADATA_KEY) === SPARK_VERSION_SHORT) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 89e6ea8604974..4e641e34c18d9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -18,9 +18,11 @@ package 
org.apache.spark.sql.hive.orc import java.net.URI +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Properties import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -31,10 +33,12 @@ import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.orc.OrcConf.COMPRESS -import org.apache.spark.TaskContext +import org.apache.spark.{SPARK_VERSION_SHORT, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -274,12 +278,14 @@ private[orc] class OrcOutputWriter( override def close(): Unit = { if (recordWriterInstantiated) { + // Hive 1.2.1 ORC initializes its private `writer` field at the first write. + OrcFileFormat.addSparkVersionMetadata(recordWriter) recordWriter.close(Reporter.NULL) } } } -private[orc] object OrcFileFormat extends HiveInspectors { +private[orc] object OrcFileFormat extends HiveInspectors with Logging { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. private[orc] val SARG_PUSHDOWN = "sarg.pushdown" @@ -339,4 +345,18 @@ private[orc] object OrcFileFormat extends HiveInspectors { val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } + + /** + * Add a metadata specifying Spark version. + */ + def addSparkVersionMetadata(recordWriter: RecordWriter[NullWritable, Writable]): Unit = { + try { + val writerField = recordWriter.getClass.getDeclaredField("writer") + writerField.setAccessible(true) + val writer = writerField.get(recordWriter).asInstanceOf[Writer] + writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT)) + } catch { + case NonFatal(e) => log.warn(e.toString, e) + } + } } From 2d085c13b7f715dbff23dd1f81af45ff903d1a79 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 10 Nov 2018 09:52:14 -0600 Subject: [PATCH 018/145] [SPARK-25984][CORE][SQL][STREAMING] Remove deprecated .newInstance(), primitive box class constructor calls ## What changes were proposed in this pull request? These calls are deprecated as of Java 11: replace Class.newInstance with Class.getConstructor().newInstance(), and primitive wrapper class constructors with valueOf or equivalent. ## How was this patch tested? Existing tests. Closes #22988 from srowen/SPARK-25984.
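Because the diff below is large and mostly mechanical, here is a minimal sketch of the two rewrites being applied; the class and values used (`java.util.ArrayList`, the boxed integers) are arbitrary examples chosen for illustration, not code taken from the patch.

```
// Sketch of the two migrations applied throughout this patch; example class/values are arbitrary.
object DeprecatedCallsDemo {
  def main(args: Array[String]): Unit = {
    val clazz = Class.forName("java.util.ArrayList")

    // Reflective instantiation: Class#newInstance is deprecated in recent JDKs because it
    // bypasses compile-time exception checking; go through the no-arg constructor instead.
    val before = clazz.newInstance()                  // deprecated
    val after  = clazz.getConstructor().newInstance() // preferred

    // Boxed primitives: the wrapper constructors are deprecated in favour of the caching
    // valueOf factories (or plain autoboxing).
    val oldInt: java.lang.Integer = new java.lang.Integer(1)     // deprecated
    val newInt: java.lang.Integer = java.lang.Integer.valueOf(1) // preferred

    println((before, after, oldInt, newInt))
  }
}
```

The hunks that follow apply these two rewrites mechanically across the modules listed in the diffstat.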
Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/api/python/PythonHadoopUtil.scala | 3 +- .../scala/org/apache/spark/api/r/SerDe.scala | 6 ++-- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../io/HadoopMapReduceCommitProtocol.scala | 2 +- .../spark/internal/io/SparkHadoopWriter.scala | 4 +-- .../apache/spark/metrics/MetricsSystem.scala | 2 +- .../org/apache/spark/rdd/BinaryFileRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 +-- .../apache/spark/rdd/WholeTextFileRDD.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 3 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 3 +- .../scala/org/apache/spark/FileSuite.scala | 2 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 14 ++++---- .../scheduler/TaskResultGetterSuite.scala | 2 +- .../spark/util/AccumulatorV2Suite.scala | 6 ++-- .../util/MutableURLClassLoaderSuite.scala | 15 ++++----- .../spark/util/SizeEstimatorSuite.scala | 32 +++++++++---------- .../spark/util/collection/SorterSuite.scala | 16 +++++----- .../unsafe/sort/RadixSortSuite.scala | 2 +- .../org/apache/spark/sql/avro/AvroSuite.scala | 2 +- .../sql/kafka010/KafkaSourceProvider.scala | 2 +- .../kafka010/DirectKafkaInputDStream.scala | 2 +- .../org/apache/spark/ml/util/ReadWrite.scala | 2 +- .../spark/ml/linalg/MatrixUDTSuite.scala | 2 +- .../spark/ml/linalg/VectorUDTSuite.scala | 2 +- .../spark/repl/ExecutorClassLoaderSuite.scala | 12 +++---- .../cluster/SchedulerExtensionService.scala | 2 +- .../sql/catalyst/JavaTypeInference.scala | 4 +-- .../spark/sql/catalyst/ScalaReflection.scala | 20 ++++++------ .../sql/catalyst/catalog/SessionCatalog.scala | 3 +- .../expressions/codegen/CodeGenerator.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../encoders/ExpressionEncoderSuite.scala | 16 +++++----- .../expressions/ObjectExpressionsSuite.scala | 8 ++--- .../aggregate/PercentileSuite.scala | 4 +-- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 4 +-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/tables.scala | 3 +- .../execution/datasources/DataSource.scala | 12 +++---- .../datasources/jdbc/DriverRegistry.scala | 2 +- .../streaming/state/StateStore.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 3 +- .../sql/streaming/DataStreamWriter.scala | 2 +- .../org/apache/spark/sql/JavaRowSuite.java | 14 ++++---- .../apache/spark/sql/JavaStringLength.java | 2 +- .../sql/ApproximatePercentileQuerySuite.scala | 6 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++---- .../org/apache/spark/sql/DatasetSuite.scala | 16 +++++----- .../scala/org/apache/spark/sql/UDFSuite.scala | 2 +- .../execution/columnar/ColumnStatsSuite.scala | 4 +-- .../sources/RateStreamProviderSuite.scala | 8 +++-- .../sources/TextSocketStreamSuite.scala | 2 +- .../sources/v2/DataSourceV2UtilsSuite.scala | 2 +- .../sources/StreamingDataSourceV2Suite.scala | 8 +++-- .../org/apache/hive/service/cli/Column.java | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../org/apache/spark/sql/hive/HiveShim.scala | 2 +- .../apache/spark/sql/hive/TableReader.scala | 6 ++-- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/execution/HiveFileFormat.scala | 3 +- .../hive/execution/HiveTableScanExec.scala | 2 +- .../execution/ScriptTransformationExec.scala | 11 ++++--- 
.../sql/hive/HiveParquetMetastoreSuite.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 4 +-- .../spark/streaming/CheckpointSuite.scala | 3 +- 70 files changed, 190 insertions(+), 174 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72123f2232532..66038eeaea54f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -261,7 +261,7 @@ object SparkEnv extends Logging { // SparkConf, then one taking no arguments try { cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) - .newInstance(conf, new java.lang.Boolean(isDriver)) + .newInstance(conf, java.lang.Boolean.valueOf(isDriver)) .asInstanceOf[T] } catch { case _: NoSuchMethodException => diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 6259bead3ea88..2ab8add63efae 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -43,7 +43,8 @@ private[python] object Converter extends Logging { defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { - val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName(cc).getConstructor(). + newInstance().asInstanceOf[Converter[Any, Any]] logInfo(s"Loaded converter: $cc") c } match { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 537ab57f9664d..6e0a3f63988d4 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -74,9 +74,9 @@ private[spark] object SerDe { jvmObjectTracker: JVMObjectTracker): Object = { dataType match { case 'n' => null - case 'i' => new java.lang.Integer(readInt(dis)) - case 'd' => new java.lang.Double(readDouble(dis)) - case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'i' => java.lang.Integer.valueOf(readInt(dis)) + case 'd' => java.lang.Double.valueOf(readDouble(dis)) + case 'b' => java.lang.Boolean.valueOf(readBoolean(dis)) case 'c' => readString(dis) case 'e' => readMap(dis, jvmObjectTracker) case 'r' => readBytes(dis) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 88df7324a354a..0fc8c9bd789e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -829,7 +829,7 @@ private[spark] class SparkSubmit extends Logging { } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { - mainClass.newInstance().asInstanceOf[SparkApplication] + mainClass.getConstructor().newInstance().asInstanceOf[SparkApplication] } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 3e60c50ada59b..7477e03bfaa76 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -91,7 +91,7 @@ class 
HadoopMapReduceCommitProtocol( private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - val format = context.getOutputFormatClass.newInstance() + val format = context.getOutputFormatClass.getConstructor().newInstance() // If OutputFormat is Configurable, we should set conf to it. format match { case c: Configurable => c.setConf(context.getConfiguration) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 9ebd0aa301592..3a58ea816937b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -256,7 +256,7 @@ class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf) private def getOutputFormat(): OutputFormat[K, V] = { require(outputFormat != null, "Must call initOutputFormat first.") - outputFormat.newInstance() + outputFormat.getConstructor().newInstance() } // -------------------------------------------------------------------------- @@ -379,7 +379,7 @@ class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfigura private def getOutputFormat(): NewOutputFormat[K, V] = { require(outputFormat != null, "Must call initOutputFormat first.") - outputFormat.newInstance() + outputFormat.getConstructor().newInstance() } // -------------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 3457a2632277d..bb7b434e9a113 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -179,7 +179,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Utils.classForName(classPath).newInstance() + val source = Utils.classForName(classPath).getConstructor().newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index a14bad47dfe10..039dbcbd5e035 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -41,7 +41,7 @@ private[spark] class BinaryFileRDD[T]( // traversing a large number of directories and files. Parallelize it. 
conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, Runtime.getRuntime.availableProcessors().toString) - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2d66d25ba39fa..483de28d92ab7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -120,7 +120,7 @@ class NewHadoopRDD[K, V]( } override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(_conf) @@ -183,7 +183,7 @@ class NewHadoopRDD[K, V]( } } - private val format = inputFormatClass.newInstance + private val format = inputFormatClass.getConstructor().newInstance() format match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index 9f3d0745c33c9..eada762b99c8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -44,7 +44,7 @@ private[spark] class WholeTextFileRDD( // traversing a large number of directories and files. Parallelize it. conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, Runtime.getRuntime.availableProcessors().toString) - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 72427dd6ce4d4..218c84352ce88 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -132,7 +132,8 @@ class KryoSerializer(conf: SparkConf) .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrators - .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) + .map(Class.forName(_, true, classLoader).getConstructor(). 
+ newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname } catch { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e35dd72521247..edae2f95fce33 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -237,7 +237,7 @@ private[spark] class BlockManager( val priorityClass = conf.get( "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName) val clazz = Utils.classForName(priorityClass) - val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy] + val ret = clazz.getConstructor().newInstance().asInstanceOf[BlockReplicationPolicy] logInfo(s"Using $priorityClass for block replication policy") ret } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 93b5826f8a74b..a07eee6ad8a4b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2430,7 +2430,8 @@ private[spark] object Utils extends Logging { "org.apache.spark.security.ShellBasedGroupsMappingProvider") if (groupProviderClassName != "") { try { - val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance. + val groupMappingServiceProvider = classForName(groupProviderClassName). + getConstructor().newInstance(). asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider] val currentUserGroups = groupMappingServiceProvider.getGroups(username) return currentUserGroups diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 34efcdf4bc886..df04a5ea1d99e 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -202,7 +202,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val objs = sc.makeRDD(1 to 3).map { x => val loader = Thread.currentThread().getContextClassLoader - Class.forName(className, true, loader).newInstance() + Class.forName(className, true, loader).getConstructor().newInstance() } val outputDir = new File(tempDir, "output").getAbsolutePath objs.saveAsObjectFile(outputDir) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 47af5c3320dd9..0ec359d1c94f3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -574,7 +574,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("saveNewAPIHadoopFile should call setConf if format is configurable") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(1)))) // No error, non-configurable formats still work pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored") @@ -591,14 +591,14 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("The JobId on the driver and executors should be the same during the commit") { // Create more than one rdd to mimic stageId not equal to rddId val pairs = sc.parallelize(Array((1, 2), (2, 3)), 2) - .map { p => (new Integer(p._1 + 1), new 
Integer(p._2 + 1)) } + .map { p => (Integer.valueOf(p._1 + 1), Integer.valueOf(p._2 + 1)) } .filter { p => p._1 > 0 } pairs.saveAsNewAPIHadoopFile[YetAnotherFakeFormat]("ignored") assert(JobID.jobid != -1) } test("saveAsHadoopFile should respect configured output committers") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(1)))) val conf = new JobConf() conf.setOutputCommitter(classOf[FakeOutputCommitter]) @@ -610,7 +610,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) FakeWriterWithCallback.calledBy = "" FakeWriterWithCallback.exception = null @@ -625,7 +625,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) val conf = new JobConf() FakeWriterWithCallback.calledBy = "" @@ -643,7 +643,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("saveAsNewAPIHadoopDataset should support invalid output paths when " + "there are no files to be committed to an absolute output location") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) def saveRddWithPath(path: String): Unit = { val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) @@ -671,7 +671,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { // for non-null invalid paths. test("saveAsHadoopDataset should respect empty output directory when " + "there are no files to be committed to an absolute output location") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) val conf = new JobConf() conf.setOutputKeyClass(classOf[Integer]) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 1bddba8f6c82b..f8eb8bd71c170 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -194,7 +194,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // jar. 
sc = new SparkContext("local", "test", conf) val rdd = sc.parallelize(Seq(1), 1).map { _ => - val exc = excClass.newInstance().asInstanceOf[Exception] + val exc = excClass.getConstructor().newInstance().asInstanceOf[Exception] throw exc } diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index 621399af731f7..172bebbfec61d 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -40,7 +40,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.avg == 0.5) // Also test add using non-specialized add function - acc.add(new java.lang.Long(2)) + acc.add(java.lang.Long.valueOf(2)) assert(acc.count == 3) assert(acc.sum == 3) assert(acc.avg == 1.0) @@ -73,7 +73,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.avg == 0.5) // Also test add using non-specialized add function - acc.add(new java.lang.Double(2.0)) + acc.add(java.lang.Double.valueOf(2.0)) assert(acc.count == 3) assert(acc.sum == 3.0) assert(acc.avg == 1.0) @@ -96,7 +96,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.value.contains(0.0)) assert(!acc.isZero) - acc.add(new java.lang.Double(1.0)) + acc.add(java.lang.Double.valueOf(1.0)) val acc2 = acc.copyAndReset() assert(acc2.value.isEmpty) diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index f6ac89fc2742a..8d844bd08771c 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -22,7 +22,6 @@ import java.net.URLClassLoader import scala.collection.JavaConverters._ import org.scalatest.Matchers -import org.scalatest.Matchers._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} @@ -46,10 +45,10 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass2").newInstance() + val fakeClass = classLoader.loadClass("FakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") - val fakeClass2 = classLoader.loadClass("FakeClass2").newInstance() + val fakeClass2 = classLoader.loadClass("FakeClass2").getConstructor().newInstance() assert(fakeClass.getClass === fakeClass2.getClass) classLoader.close() parentLoader.close() @@ -58,10 +57,10 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("parent first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new MutableURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass1").newInstance() + val fakeClass = classLoader.loadClass("FakeClass1").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") - val fakeClass2 = classLoader.loadClass("FakeClass1").newInstance() + val fakeClass2 = classLoader.loadClass("FakeClass1").getConstructor().newInstance() assert(fakeClass.getClass === fakeClass2.getClass) classLoader.close() parentLoader.close() @@ -70,7 +69,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("child first can fall back") { val parentLoader = 
new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass3").newInstance() + val fakeClass = classLoader.loadClass("FakeClass3").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") classLoader.close() @@ -81,7 +80,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("FakeClassDoesNotExist").newInstance() + classLoader.loadClass("FakeClassDoesNotExist").getConstructor().newInstance() } classLoader.close() parentLoader.close() @@ -137,7 +136,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { sc.makeRDD(1 to 5, 2).mapPartitions { x => val loader = Thread.currentThread().getContextClassLoader // scalastyle:off classforname - Class.forName(className, true, loader).newInstance() + Class.forName(className, true, loader).getConstructor().newInstance() // scalastyle:on classforname Seq().iterator }.count() diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 2695295d451d5..63f9f82adf3e0 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -92,14 +92,14 @@ class SizeEstimatorSuite } test("primitive wrapper objects") { - assertResult(16)(SizeEstimator.estimate(new java.lang.Boolean(true))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Byte("1"))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Character('1'))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Short("1"))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Integer(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Float(1.0))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + assertResult(16)(SizeEstimator.estimate(java.lang.Boolean.TRUE)) + assertResult(16)(SizeEstimator.estimate(java.lang.Byte.valueOf("1"))) + assertResult(16)(SizeEstimator.estimate(java.lang.Character.valueOf('1'))) + assertResult(16)(SizeEstimator.estimate(java.lang.Short.valueOf("1"))) + assertResult(16)(SizeEstimator.estimate(java.lang.Integer.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Long.valueOf(1))) + assertResult(16)(SizeEstimator.estimate(java.lang.Float.valueOf(1.0f))) + assertResult(24)(SizeEstimator.estimate(java.lang.Double.valueOf(1.0))) } test("class field blocks rounding") { @@ -202,14 +202,14 @@ class SizeEstimatorSuite assertResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) // primitive wrapper classes - assertResult(24)(SizeEstimator.estimate(new java.lang.Boolean(true))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Byte("1"))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Character('1'))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Short("1"))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Integer(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Float(1.0))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + assertResult(24)(SizeEstimator.estimate(java.lang.Boolean.TRUE)) 
+ assertResult(24)(SizeEstimator.estimate(java.lang.Byte.valueOf("1"))) + assertResult(24)(SizeEstimator.estimate(java.lang.Character.valueOf('1'))) + assertResult(24)(SizeEstimator.estimate(java.lang.Short.valueOf("1"))) + assertResult(24)(SizeEstimator.estimate(java.lang.Integer.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Long.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Float.valueOf(1.0f))) + assertResult(24)(SizeEstimator.estimate(java.lang.Double.valueOf(1.0))) } test("class field blocks rounding on 64-bit VM without useCompressedOops") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 65bf857e22c02..46a05e2ba798b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.lang.{Float => JFloat, Integer => JInteger} +import java.lang.{Float => JFloat} import java.util.{Arrays, Comparator} import org.apache.spark.SparkFunSuite @@ -48,7 +48,7 @@ class SorterSuite extends SparkFunSuite with Logging { // alternate. Keys are random doubles, values are ordinals from 0 to length. val keys = Array.tabulate[Double](5000) { i => rand.nextDouble() } val keyValueArray = Array.tabulate[Number](10000) { i => - if (i % 2 == 0) keys(i / 2) else new Integer(i / 2) + if (i % 2 == 0) keys(i / 2) else Integer.valueOf(i / 2) } // Map from generated keys to values, to verify correctness later @@ -112,7 +112,7 @@ class SorterSuite extends SparkFunSuite with Logging { // Test our key-value pairs where each element is a Tuple2[Float, Integer]. val kvTuples = Array.tabulate(numElements) { i => - (new JFloat(rand.nextFloat()), new JInteger(i)) + (JFloat.valueOf(rand.nextFloat()), Integer.valueOf(i)) } val kvTupleArray = new Array[AnyRef](numElements) @@ -167,23 +167,23 @@ class SorterSuite extends SparkFunSuite with Logging { val ints = Array.fill(numElements)(rand.nextInt()) val intObjects = { - val data = new Array[JInteger](numElements) + val data = new Array[Integer](numElements) var i = 0 while (i < numElements) { - data(i) = new JInteger(ints(i)) + data(i) = Integer.valueOf(ints(i)) i += 1 } data } - val intObjectArray = new Array[JInteger](numElements) + val intObjectArray = new Array[Integer](numElements) val prepareIntObjectArray = () => { System.arraycopy(intObjects, 0, intObjectArray, 0, numElements) } runExperiment("Java Arrays.sort() on non-primitive int array")({ - Arrays.sort(intObjectArray, new Comparator[JInteger] { - override def compare(x: JInteger, y: JInteger): Int = x.compareTo(y) + Arrays.sort(intObjectArray, new Comparator[Integer] { + override def compare(x: Integer, y: Integer): Int = x.compareTo(y) }) }, prepareIntObjectArray) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096a..d570630c1a095 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -78,7 +78,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } 
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => JLong.valueOf(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 4fea2cb969446..8d6cca8e48c3d 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -508,7 +508,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect() assert( union2 - .map(x => new java.lang.Double(x(0).toString)) + .map(x => java.lang.Double.valueOf(x(0).toString)) .exists(p => Math.abs(p - Math.PI) < 0.001)) val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect() diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 28c9853bfea9c..5034bd73d6e74 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -510,7 +510,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") // So that the driver does not pull too much data - .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, java.lang.Integer.valueOf(1)) // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index ba4009ef08856..224f41a683955 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -70,7 +70,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( @transient private var kc: Consumer[K, V] = null def consumer(): Consumer[K, V] = this.synchronized { if (null == kc) { - kc = consumerStrategy.onStart(currentOffsets.mapValues(l => new java.lang.Long(l)).asJava) + kc = consumerStrategy.onStart(currentOffsets.mapValues(l => java.lang.Long.valueOf(l)).asJava) } kc } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a0ac26a34d8c8..d985f8ca1ecc7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -256,7 +256,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { s"Multiple writers found for $source+$stageName, try using the class name of the writer") } if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { - val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + val writer = 
writerCls.getConstructor().newInstance().asInstanceOf[MLWriterFormat] writer.write(path, sparkSession, optionMap, stage) } else { throw new SparkException(s"ML source $source is not a valid MLWriterFormat") diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala index bdceba7887cac..8371c33a209dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala @@ -31,7 +31,7 @@ class MatrixUDTSuite extends SparkFunSuite { val sm3 = dm3.toSparse for (m <- Seq(dm1, dm2, dm3, sm1, sm2, sm3)) { - val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.getConstructor().newInstance() .asInstanceOf[MatrixUDT] assert(m === udt.deserialize(udt.serialize(m))) assert(udt.typeName == "matrix") diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 6ddb12cb76aac..67c64f762b25e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -31,7 +31,7 @@ class VectorUDTSuite extends SparkFunSuite { val sv2 = Vectors.sparse(2, Array(1), Array(2.0)) for (v <- Seq(dv1, dv2, sv1, sv2)) { - val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.getConstructor().newInstance() .asInstanceOf[VectorUDT] assert(v === udt.deserialize(udt.serialize(v))) assert(udt.typeName == "vector") diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index e5e2094368fb0..ac528ecb829b0 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -126,7 +126,7 @@ class ExecutorClassLoaderSuite test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) - val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") } @@ -134,7 +134,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) - val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass1").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") } @@ -142,7 +142,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) - val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass3").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") } @@ -151,7 +151,7 @@ class ExecutorClassLoaderSuite val parentLoader = new 
URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance() } } @@ -202,11 +202,11 @@ class ExecutorClassLoaderSuite val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", getClass().getClassLoader(), false) - val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance() } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala index 4ed285230ff81..7d15f0e2fbac8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -107,7 +107,7 @@ private[spark] class SchedulerExtensionServices extends SchedulerExtensionServic services = sparkContext.conf.get(SCHEDULER_SERVICES).map { sClass => val instance = Utils.classForName(sClass) - .newInstance() + .getConstructor().newInstance() .asInstanceOf[SchedulerExtensionService] // bind this service instance.start(binding) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 8ef8b2be6939c..311060e5961cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -73,10 +73,10 @@ object JavaTypeInference { : (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true) case c: Class[_] if UDTRegistration.exists(c.getName) => - val udt = UDTRegistration.getUDTFor(c.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().newInstance() .asInstanceOf[UserDefinedType[_ >: Null]] (udt, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 912744eab6a3a..64ea236532839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -357,7 +357,8 @@ object ScalaReflection extends ScalaReflection { ) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). 
+ getConstructor().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -365,8 +366,8 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, @@ -601,7 +602,7 @@ object ScalaReflection extends ScalaReflection { case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -609,8 +610,8 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "serialize", udt, inputObject :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, @@ -721,11 +722,12 @@ object ScalaReflection extends ScalaReflection { // Null type would wrongly match the first of them, which is Option as of now case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). + getConstructor().newInstance() Schema(udt, nullable = true) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). 
+ newInstance().asInstanceOf[UserDefinedType[_]] Schema(udt, nullable = true) case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c11b444212946..b6771ec4dffe9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1134,7 +1134,8 @@ class SessionCatalog( if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) - .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) + .newInstance(input, + clazz.getConstructor().newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) .asInstanceOf[ImplicitCastInputTypes] // Check input argument size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b868a0f4fa284..7c8f7cd4315b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1305,7 +1305,7 @@ object CodeGenerator extends Logging { throw new CompileException(msg, e.getLocation) } - (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) + (evaluator.getClazz().getConstructor().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e53628d11ccf3..33fc4b9480126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -180,7 +180,7 @@ object DataType { ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => - Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + Utils.classForName(udtClass).getConstructor().newInstance().asInstanceOf[UserDefinedType[_]] // Python UDT case JSortedObject( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e9b100b3b30db..be8fd90c4c52a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -128,13 +128,13 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(-3.7f, "primitive float") encodeDecodeTest(-3.7, "primitive double") - encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") - encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") - encodeDecodeTest(new 
java.lang.Double(-3.7), "boxed double") + encodeDecodeTest(java.lang.Boolean.FALSE, "boxed boolean") + encodeDecodeTest(java.lang.Byte.valueOf(-3: Byte), "boxed byte") + encodeDecodeTest(java.lang.Short.valueOf(-3: Short), "boxed short") + encodeDecodeTest(java.lang.Integer.valueOf(-3), "boxed int") + encodeDecodeTest(java.lang.Long.valueOf(-3L), "boxed long") + encodeDecodeTest(java.lang.Float.valueOf(-3.7f), "boxed float") + encodeDecodeTest(java.lang.Double.valueOf(-3.7), "boxed double") encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") @@ -224,7 +224,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes productTest( RepeatedData( Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), + Seq(Integer.valueOf(1), null, Integer.valueOf(2)), Map(1 -> 2L), Map(1 -> null), PrimitiveData(1, 1, 1, 1, 1, 1, true))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d145fd0aaba47..16842c1bcc8cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -307,7 +307,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val conf = new SparkConf() Seq(true, false).foreach { useKryo => val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) - val expected = serializer.newInstance().serialize(new Integer(1)).array() + val expected = serializer.newInstance().serialize(Integer.valueOf(1)).array() val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) @@ -384,9 +384,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val conf = new SparkConf() Seq(true, false).foreach { useKryo => val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) - val input = serializer.newInstance().serialize(new Integer(1)).array() + val input = serializer.newInstance().serialize(Integer.valueOf(1)).array() val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) - checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, Integer.valueOf(1), InternalRow.fromSeq(Seq(input))) checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } @@ -575,7 +575,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // NULL key test val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( - null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + null.asInstanceOf[java.lang.Integer] -> "v0", java.lang.Integer.valueOf(1) -> "v1") val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { { put(null, "v0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 294fce8e9a10f..63c7b42978025 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -41,9 +41,9 @@ class PercentileSuite extends SparkFunSuite { val buffer = new OpenHashMap[AnyRef, Long]() assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) - // Check non-empty buffer serializa and deserialize. + // Check non-empty buffer serialize and deserialize. data.foreach { key => - buffer.changeValue(new Integer(key), 1L, _ + 1L) + buffer.changeValue(Integer.valueOf(key), 1L, _ + 1L) } assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 02ffc940184db..df18623e42a02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -194,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val ds = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2] if (ds.isInstanceOf[BatchReadSupportProvider]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5a28870f5d3c2..1b4998f94b25d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -243,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val source = cls.newInstance().asInstanceOf[DataSourceV2] + val source = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2] source match { case provider: BatchWriteSupportProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 71f967a59d77e..c0727e844a1ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1144,7 +1144,7 @@ object SparkSession extends Logging { val extensionConfClassName = extensionOption.get try { val extensionConfClass = Utils.classForName(extensionConfClassName) - val extensionConf = extensionConfClass.newInstance() + val extensionConf = extensionConfClass.getConstructor().newInstance() .asInstanceOf[SparkSessionExtensions => Unit] extensionConf(extensions) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index aa3a6c3bf122f..84da097be53c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -670,7 +670,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends throw new AnalysisException(s"It 
is invalid to implement multiple UDF interfaces, UDF class $className") } else { try { - val udf = clazz.newInstance() + val udf = clazz.getConstructor().newInstance() val udfReturnType = udfInterfaces(0).getActualTypeArguments.last var returnType = returnDataType if (returnType == null) { @@ -727,7 +727,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction") } - val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] + val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction] register(name, udaf) } catch { case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index af20764f9a968..becb05cf72aba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -111,7 +111,7 @@ private[sql] object SQLUtils extends Logging { private[this] def doConversion(data: Object, dataType: DataType): Object = { data match { case d: java.lang.Double if dataType == FloatType => - new java.lang.Float(d) + java.lang.Float.valueOf(d.toFloat) // Scala Map is the only allowed external type of map type in Row. case m: java.util.Map[_, _] => m.asScala case _ => data diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 823dc0d5ed387..e2cd40906f401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -231,7 +231,8 @@ case class AlterTableAddColumnsCommand( } if (DDLUtils.isDatasourceTable(catalogTable)) { - DataSource.lookupDataSource(catalogTable.provider.get, conf).newInstance() match { + DataSource.lookupDataSource(catalogTable.provider.get, conf). + getConstructor().newInstance() match { // For datasource table, this command can only support the following File format. // TextFileFormat only default to one column "value" // Hive type is already considered as hive serde table, so the logic will not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index ce3bc3dd48327..795a6d0b6b040 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -204,7 +204,7 @@ case class DataSource( /** Returns the name and schema of the source that can be used to continually read data. */ private def sourceSchema(): SourceInfo = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( sparkSession.sqlContext, userSpecifiedSchema, className, caseInsensitiveOptions) @@ -250,7 +250,7 @@ case class DataSource( /** Returns a source that can be used to continually read data. 
*/ def createSource(metadataPath: String): Source = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSourceProvider => s.createSource( sparkSession.sqlContext, @@ -279,7 +279,7 @@ case class DataSource( /** Returns a sink that can be used to continually write data. */ def createSink(outputMode: OutputMode): Sink = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, caseInsensitiveOptions, partitionColumns, outputMode) @@ -310,7 +310,7 @@ case class DataSource( * that files already exist, we don't need to check them again. */ def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = { - val relation = (providingClass.newInstance(), userSpecifiedSchema) match { + val relation = (providingClass.getConstructor().newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. case (dataSource: SchemaRelationProvider, Some(schema)) => dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) @@ -479,7 +479,7 @@ case class DataSource( throw new AnalysisException("Cannot save interval data type into external storage.") } - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) @@ -516,7 +516,7 @@ case class DataSource( throw new AnalysisException("Cannot save interval data type into external storage.") } - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 1723596de1db2..530d836d9fde3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -50,7 +50,7 @@ object DriverRegistry extends Logging { } else { synchronized { if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + val wrapper = new DriverWrapper(cls.getConstructor().newInstance().asInstanceOf[Driver]) DriverManager.registerDriver(wrapper) wrapperMap(className) = wrapper logTrace(s"Wrapper for $className registered") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d3313b8a315c9..7d785aa09cd9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -213,7 +213,7 @@ object StateStoreProvider { */ def create(providerClassName: String): StateStoreProvider = { val providerClass = Utils.classForName(providerClassName) - providerClass.newInstance().asInstanceOf[StateStoreProvider] + providerClass.getConstructor().newInstance().asInstanceOf[StateStoreProvider] } /** diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 20c84305776ae..bf6021e692382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -158,7 +158,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo "read files of Hive data source directly.") } - val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() + val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf). + getConstructor().newInstance() // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 4a8c7fdb58ff1..b36a8f3f6f15b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -307,7 +307,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") var options = extraOptions.toMap - val sink = ds.newInstance() match { + val sink = ds.getConstructor().newInstance() match { case w: StreamingWriteSupportProvider if !disabledSources.contains(w.getClass.getCanonicalName) => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 3ab4db2a035d3..ca78d6489ef5c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -67,20 +67,20 @@ public void setUp() { public void constructSimpleRow() { Row simpleRow = RowFactory.create( byteValue, // ByteType - new Byte(byteValue), + Byte.valueOf(byteValue), shortValue, // ShortType - new Short(shortValue), + Short.valueOf(shortValue), intValue, // IntegerType - new Integer(intValue), + Integer.valueOf(intValue), longValue, // LongType - new Long(longValue), + Long.valueOf(longValue), floatValue, // FloatType - new Float(floatValue), + Float.valueOf(floatValue), doubleValue, // DoubleType - new Double(doubleValue), + Double.valueOf(doubleValue), decimalValue, // DecimalType booleanValue, // BooleanType - new Boolean(booleanValue), + Boolean.valueOf(booleanValue), stringValue, // StringType binaryValue, // BinaryType dateValue, // DateType diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java index b90224f2ae397..5955eabe496df 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java @@ -25,6 +25,6 @@ public class JavaStringLength implements UDF1 { @Override public Integer call(String str) throws Exception { - return new Integer(str.length()); + return Integer.valueOf(str.length()); } } diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index d635912cf7205..52708f5fe4108 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -208,7 +208,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { test("percentile_approx(col, ...), input rows contains null, with out group by") { withTempView(table) { - (1 to 1000).map(new Integer(_)).flatMap(Seq(null: Integer, _)).toDF("col") + (1 to 1000).map(Integer.valueOf(_)).flatMap(Seq(null: Integer, _)).toDF("col") .createOrReplaceTempView(table) checkAnswer( spark.sql( @@ -226,8 +226,8 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { withTempView(table) { val rand = new java.util.Random() (1 to 1000) - .map(new Integer(_)) - .map(v => (new Integer(v % 2), v)) + .map(Integer.valueOf(_)) + .map(v => (Integer.valueOf(v % 2), v)) // Add some nulls .flatMap(Seq(_, (null: Integer, null: Integer))) .toDF("key", "value").createOrReplaceTempView(table) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index edde9bfd088cf..2bb18f48e0ae2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1986,7 +1986,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11725: correctly handle null inputs for ScalaUDF") { val df = sparkContext.parallelize(Seq( - new java.lang.Integer(22) -> "John", + java.lang.Integer.valueOf(22) -> "John", null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") // passing null into the UDF that could handle it @@ -2219,9 +2219,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: no change on nullability in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "Rand()", expectedNonNullableColumns = Seq.empty[String]) @@ -2236,9 +2236,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: set nullability to false in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 82d3b22a48670..75d06510376ac 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -697,15 +697,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-11894: Incorrect results are returned when using null") { val nullInt = null.asInstanceOf[java.lang.Integer] - val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() - val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds1 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS() checkDataset( ds1.joinWith(ds2, lit(true), "cross"), ((nullInt, "1"), (nullInt, "1")), - ((nullInt, "1"), (new java.lang.Integer(22), "2")), - ((new java.lang.Integer(22), "2"), (nullInt, "1")), - ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + ((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")), + ((java.lang.Integer.valueOf(22), "2"), (nullInt, "1")), + ((java.lang.Integer.valueOf(22), "2"), (java.lang.Integer.valueOf(22), "2"))) } test("change encoder with compatible schema") { @@ -881,7 +881,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.rdd.map(r => r.id).count === 2) assert(ds2.rdd.map(r => r.id).count === 2) - val ds3 = ds.map(g => new java.lang.Long(g.id)) + val ds3 = ds.map(g => java.lang.Long.valueOf(g.id)) assert(ds3.rdd.map(r => r).count === 2) } @@ -1499,7 +1499,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) withTempPath { path => - Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath) + Seq(Integer.valueOf(1), null).toDF("i").write.parquet(path.getCanonicalPath) // If the primitive values are from files, we need to do runtime null check. 
val ds = spark.read.parquet(path.getCanonicalPath).as[Int] intercept[NullPointerException](ds.collect()) @@ -1553,7 +1553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") checkAnswer(df.where('city === 'X'), Seq(Row("X"))) checkAnswer( - df.where($"city".contains(new java.lang.Character('A'))), + df.where($"city".contains(java.lang.Character.valueOf('A'))), Seq(Row("Amsterdam"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 3b301a4f8144a..20dcefa7e3cad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -413,7 +413,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { test("SPARK-25044 Verify null input handling for primitive types - with udf.register") { withTable("t") { - Seq((null, new Integer(1), "x"), ("M", null, "y"), ("N", new Integer(3), null)) + Seq((null, Integer.valueOf(1), "x"), ("M", null, "y"), ("N", Integer.valueOf(3), null)) .toDF("a", "b", "c").write.format("json").saveAsTable("t") spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c) val df = spark.sql("SELECT f(a, b, c) FROM t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index d4e7e362c6c8c..3121b7e99c99d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -39,7 +39,7 @@ class ColumnStatsSuite extends SparkFunSuite { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { - val columnStats = columnStatsClass.newInstance() + val columnStats = columnStatsClass.getConstructor().newInstance() columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } @@ -48,7 +48,7 @@ class ColumnStatsSuite extends SparkFunSuite { test(s"$columnStatsName: non-empty") { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ - val columnStats = columnStatsClass.newInstance() + val columnStats = columnStatsClass.getConstructor().newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index dd74af873c2e5..be3efed714030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -53,7 +53,8 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + DataSource.lookupDataSource("rate", spark.sqlContext.conf). 
+ getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => val readSupport = ds.createMicroBatchReadSupport( temp.getCanonicalPath, DataSourceOptions.empty()) @@ -66,7 +67,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { + spark.sqlContext.conf).getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[RateStreamProvider]) case _ => @@ -320,7 +321,8 @@ class RateSourceSuite extends StreamTest { } test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + DataSource.lookupDataSource("rate", spark.sqlContext.conf). + getConstructor().newInstance() match { case ds: ContinuousReadSupportProvider => val readSupport = ds.createContinuousReadSupport( "", DataSourceOptions.empty()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 409156e5ebc70..635ea6fca649c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -84,7 +84,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", - spark.sqlContext.conf).newInstance() match { + spark.sqlContext.conf).getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 4911e3225552d..f903c17923d0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -33,7 +33,7 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") conf.setConfString(s"spark.datasource.$keyPrefix.", "123") - val cs = classOf[DataSourceV2WithSessionConfig].newInstance() + val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance() val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3a0e780a73915..31fce46c2daba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -261,7 +261,7 @@ class StreamingDataSourceV2Suite extends StreamTest { ).foreach { case (source, trigger) => test(s"SPARK-25460: session options are respected in structured 
streaming sources - $source") { // `keyPrefix` and `shortName` are the same in this test case - val readSource = source.newInstance().shortName() + val readSource = source.getConstructor().newInstance().shortName() val writeSource = "fake-write-microbatch-continuous" val readOptionName = "optionA" @@ -299,8 +299,10 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() - val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf). + getConstructor().newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). + getConstructor().newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java index adb269aa235ea..26d0f718f383a 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java @@ -349,7 +349,7 @@ public void addValue(Type type, Object field) { break; case FLOAT_TYPE: nulls.set(size, field == null); - doubleVars()[size] = field == null ? 0 : new Double(field.toString()); + doubleVars()[size] = field == null ? 0 : Double.valueOf(field.toString()); break; case DOUBLE_TYPE: nulls.set(size, field == null); diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d047953327958..5823548a8063c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -124,7 +124,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val tablePath = new Path(relation.tableMeta.location) - val fileFormat = fileFormatClass.newInstance() + val fileFormat = fileFormatClass.getConstructor().newInstance() val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 11afe1af32809..c9fc3d4a02c4b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -217,7 +217,7 @@ private[hive] object HiveShim { instance.asInstanceOf[UDFType] } else { val func = Utils.getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + .loadClass(functionClassName).getConstructor().newInstance().asInstanceOf[UDFType] if (!func.isInstanceOf[UDF]) { // We cache the function if it's no the Simple UDF, // as we always have to create new instance for Simple UDF diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 9443fbb4330a5..536bc4a3f4ec4 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -132,7 +132,7 @@ class HadoopTableReader( val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value - val deserializer = deserializerClass.newInstance() + val deserializer = deserializerClass.getConstructor().newInstance() deserializer.initialize(hconf, localTableDesc.getProperties) HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } @@ -245,7 +245,7 @@ class HadoopTableReader( val localTableDesc = tableDesc createHadoopRdd(localTableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value - val deserializer = localDeserializer.newInstance() + val deserializer = localDeserializer.getConstructor().newInstance() // SPARK-13709: For SerDes like AvroSerDe, some essential information (e.g. Avro schema // information) may be defined in table properties. Here we should merge table properties // and partition properties before initializing the deserializer. Note that partition @@ -257,7 +257,7 @@ class HadoopTableReader( } deserializer.initialize(hconf, props) // get the table deserializer - val tableSerDe = localTableDesc.getDeserializerClass.newInstance() + val tableSerDe = localTableDesc.getDeserializerClass.getConstructor().newInstance() tableSerDe.initialize(hconf, localTableDesc.getProperties) // fill the non partition key attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index bc9d4cd7f4181..4d484904d2c27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -987,7 +987,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { part: JList[String], deleteData: Boolean, purge: Boolean): Unit = { - val dropOptions = dropOptionsClass.newInstance().asInstanceOf[Object] + val dropOptions = dropOptionsClass.getConstructor().newInstance().asInstanceOf[Object] dropOptionsDeleteData.setBoolean(dropOptions, deleteData) dropOptionsPurge.setBoolean(dropOptions, purge) dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index 4a7cd6901923b..d8d2a80e0e8b7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -115,7 +115,8 @@ class HiveOutputWriter( private def tableDesc = fileSinkConf.getTableInfo private val serializer = { - val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + val serializer = tableDesc.getDeserializerClass.getConstructor(). 
+ newInstance().asInstanceOf[Serializer] serializer.initialize(jobConf, tableDesc.getProperties) serializer } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 92c6632ad7863..fa940fe73bd13 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -120,7 +120,7 @@ case class HiveTableScanExec( HiveShim.appendReadColumns(hiveConf, neededColumnIDs, output.map(_.name)) - val deserializer = tableDesc.getDeserializerClass.newInstance + val deserializer = tableDesc.getDeserializerClass.getConstructor().newInstance() deserializer.initialize(hiveConf, tableDesc.getProperties) // Specifies types and object inspectors of columns to be scanned. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index 3328400b214fb..7b35a5f920ae9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -123,7 +123,7 @@ case class ScriptTransformationExec( var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { - outputSerde.getSerializedClass().newInstance + outputSerde.getSerializedClass().getConstructor().newInstance() } else { null } @@ -404,7 +404,8 @@ case class HiveScriptIOSchema ( columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { - val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] + val serde = Utils.classForName(serdeClassName).getConstructor(). + newInstance().asInstanceOf[AbstractSerDe] val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") @@ -424,7 +425,8 @@ case class HiveScriptIOSchema ( inputStream: InputStream, conf: Configuration): Option[RecordReader] = { recordReaderClass.map { klass => - val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val instance = Utils.classForName(klass).getConstructor(). + newInstance().asInstanceOf[RecordReader] val props = new Properties() // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 @@ -436,7 +438,8 @@ case class HiveScriptIOSchema ( def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { recordWriterClass.map { klass => - val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + val instance = Utils.classForName(klass).getConstructor(). 
+ newInstance().asInstanceOf[RecordWriter] instance.initialize(outputStream, conf) instance } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala index 0d4f040156084..68a0c1213ec20 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala @@ -152,7 +152,7 @@ class HiveParquetMetastoreSuite extends ParquetPartitioningTest { } (1 to 10).map(i => (i, s"str$i")).toDF("a", "b").createOrReplaceTempView("jt") - (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a") + (1 to 10).map(i => Tuple1(Seq(Integer.valueOf(i), null))).toDF("a") .createOrReplaceTempView("jt_array") assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8d83dc8a8fc04..6f0b46b6a4cb3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -49,11 +49,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.util.SystemClock") try { - Utils.classForName(clockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(clockClass).getConstructor().newInstance().asInstanceOf[Clock] } catch { case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") - Utils.classForName(newClockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(newClockClass).getConstructor().newInstance().asInstanceOf[Clock] } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 19b621f11759d..2332ee2ab9de1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -808,7 +808,8 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester // visible to mutableURLClassLoader val loader = new MutableURLClassLoader( Array(jar), appClassLoader) - assert(loader.loadClass("testClz").newInstance().toString == "testStringValue") + assert(loader.loadClass("testClz").getConstructor().newInstance().toString === + "testStringValue") // create and serialize Array[testClz] // scalastyle:off classforname From 6cd23482d1ae8c6a9fe9817ed51ee2a039d46649 Mon Sep 17 00:00:00 2001 From: Patrick Brown Date: Sat, 10 Nov 2018 12:51:24 -0600 Subject: [PATCH 019/145] [SPARK-25839][CORE] Implement use of KryoPool in KryoSerializer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? * Implement (optional) use of KryoPool in KryoSerializer, an alternative to the existing implementation of caching a Kryo instance inside KryoSerializerInstance * Add config key & documentation of spark.kryo.pool in order to turn this on * Add benchmark KryoSerializerBenchmark to compare new and old implementation * Add results of benchmark ## How was this patch tested? 
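For illustration, both code paths can be exercised by flipping the new `spark.kryo.pool` key (default `true` per this patch); the following is a minimal sketch, assuming only the public `KryoSerializer` API touched in this diff, not a test added by the patch itself:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// "false" keeps the old per-SerializerInstance "pool of 1" behaviour;
// omit the setting (or use "true") to get the shared, soft-referenced KryoPool.
val conf = new SparkConf().set("spark.kryo.pool", "false")

val instance = new KryoSerializer(conf).newInstance()
val bytes = instance.serialize("round trip")                 // borrows a Kryo instance
assert(instance.deserialize[String](bytes) == "round trip")  // and releases it afterwards
```
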
Added new tests inside KryoSerializerSuite to test the pool implementation as well as added the pool option to the existing regression testing for SPARK-7766 This is my original work and I license the work to the project under the project’s open source license. Closes #22855 from patrickbrownsync/kryo-pool. Authored-by: Patrick Brown Signed-off-by: Sean Owen --- .../KryoSerializerBenchmark-results.txt | 12 +++ .../spark/serializer/KryoSerializer.scala | 72 ++++++++++++--- .../spark/benchmark/BenchmarkBase.scala | 7 ++ .../serializer/KryoSerializerBenchmark.scala | 90 +++++++++++++++++++ .../serializer/KryoSerializerSuite.scala | 66 ++++++++++++-- 5 files changed, 230 insertions(+), 17 deletions(-) create mode 100644 core/benchmarks/KryoSerializerBenchmark-results.txt create mode 100644 core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt new file mode 100644 index 0000000000000..c3ce336d93241 --- /dev/null +++ b/core/benchmarks/KryoSerializerBenchmark-results.txt @@ -0,0 +1,12 @@ +================================================================================================ +Benchmark KryoPool vs "pool of 1" +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.14 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz +Benchmark KryoPool vs "pool of 1": Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +KryoPool:true 2682 / 3425 0.0 5364627.9 1.0X +KryoPool:false 8176 / 9292 0.0 16351252.2 0.3X + + diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 218c84352ce88..3795d5c3b38e3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} +import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} @@ -84,6 +85,7 @@ class KryoSerializer(conf: SparkConf) private val avroSchemas = conf.getAvroSchema // whether to use unsafe based IO for serialization private val useUnsafe = conf.getBoolean("spark.kryo.unsafe", false) + private val usePool = conf.getBoolean("spark.kryo.pool", true) def newKryoOutput(): KryoOutput = if (useUnsafe) { @@ -92,6 +94,36 @@ class KryoSerializer(conf: SparkConf) new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) } + @transient + private lazy val factory: KryoFactory = new KryoFactory() { + override def create: Kryo = { + newKryo() + } + } + + private class PoolWrapper extends KryoPool { + private var pool: KryoPool = getPool + + override def borrow(): Kryo = pool.borrow() + + override def release(kryo: Kryo): Unit = pool.release(kryo) + + override def run[T](kryoCallback: 
KryoCallback[T]): T = pool.run(kryoCallback) + + def reset(): Unit = { + pool = getPool + } + + private def getPool: KryoPool = { + new KryoPool.Builder(factory).softReferences.build + } + } + + @transient + private lazy val internalPool = new PoolWrapper + + def pool: KryoPool = internalPool + def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() @@ -215,8 +247,14 @@ class KryoSerializer(conf: SparkConf) kryo } + override def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + super.setDefaultClassLoader(classLoader) + internalPool.reset() + this + } + override def newInstance(): SerializerInstance = { - new KryoSerializerInstance(this, useUnsafe) + new KryoSerializerInstance(this, useUnsafe, usePool) } private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = { @@ -299,7 +337,8 @@ class KryoDeserializationStream( } } -private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boolean) +private[spark] class KryoSerializerInstance( + ks: KryoSerializer, useUnsafe: Boolean, usePool: Boolean) extends SerializerInstance { /** * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do @@ -307,22 +346,29 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are * not synchronized. */ - @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + @Nullable private[this] var cachedKryo: Kryo = if (usePool) null else borrowKryo() /** * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; * otherwise, it allocates a new instance. */ private[serializer] def borrowKryo(): Kryo = { - if (cachedKryo != null) { - val kryo = cachedKryo - // As a defensive measure, call reset() to clear any Kryo state that might have been modified - // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + if (usePool) { + val kryo = ks.pool.borrow() kryo.reset() - cachedKryo = null kryo } else { - ks.newKryo() + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have + // been modified by the last operation to borrow this instance + // (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } } } @@ -332,8 +378,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole * re-use. 
*/ private[serializer] def releaseKryo(kryo: Kryo): Unit = { - if (cachedKryo == null) { - cachedKryo = kryo + if (usePool) { + ks.pool.release(kryo) + } else { + if (cachedKryo == null) { + cachedKryo = kryo + } } } diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala index 24e596e1ecdaf..a6666db4e95c3 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala @@ -58,5 +58,12 @@ abstract class BenchmarkBase { o.close() } } + + afterAll() } + + /** + * Any shutdown code to ensure a clean shutdown + */ + def afterAll(): Unit = {} } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala new file mode 100644 index 0000000000000..2a15c6f6a2d96 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import scala.concurrent._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration._ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.ThreadUtils + +/** + * Benchmark for KryoPool vs old "pool of 1". + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class --jars + * 2. build/sbt "core/test:runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain " + * Results will be written to "benchmarks/KryoSerializerBenchmark-results.txt". 
+ * }}} + */ +object KryoSerializerBenchmark extends BenchmarkBase { + + var sc: SparkContext = null + val N = 500 + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val name = "Benchmark KryoPool vs old\"pool of 1\" implementation" + runBenchmark(name) { + val benchmark = new Benchmark(name, N, 10, output = output) + Seq(true, false).foreach(usePool => run(usePool, benchmark)) + benchmark.run() + } + } + + private def run(usePool: Boolean, benchmark: Benchmark): Unit = { + lazy val sc = createSparkContext(usePool) + + benchmark.addCase(s"KryoPool:$usePool") { _ => + val futures = for (_ <- 0 until N) yield { + Future { + sc.parallelize(0 until 10).map(i => i + 1).count() + } + } + + val future = Future.sequence(futures) + + ThreadUtils.awaitResult(future, 10.minutes) + } + } + + def createSparkContext(usePool: Boolean): SparkContext = { + val conf = new SparkConf() + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) + conf.set("spark.kryo.pool", usePool.toString) + + if (sc != null) { + sc.stop() + } + + sc = new SparkContext("local-cluster[4,1,1024]", "test", conf) + sc + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index ac25bcef54349..84af73b08d3e7 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -18,9 +18,12 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.util.concurrent.Executors import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} @@ -31,7 +34,7 @@ import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") @@ -308,7 +311,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val conf = new SparkConf(false) conf.set("spark.kryo.registrator", "this.class.does.not.exist") - val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance().serialize(1)) assert(thrown.getMessage.contains("Failed to register classes with Kryo")) } @@ -431,9 +434,11 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { ser.deserialize[HashMap[Int, List[String]]](serializedMap) } - private def testSerializerInstanceReuse(autoReset: Boolean, referenceTracking: Boolean): Unit = { + private def testSerializerInstanceReuse( + autoReset: Boolean, referenceTracking: Boolean, usePool: Boolean): Unit = { val conf = new SparkConf(loadDefaults = false) .set("spark.kryo.referenceTracking", referenceTracking.toString) + .set("spark.kryo.pool", usePool.toString) if 
(!autoReset) { conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) } @@ -456,9 +461,58 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { // Regression test for SPARK-7766, an issue where disabling auto-reset and enabling // reference-tracking would lead to corrupted output when serializer instances are re-used - for (referenceTracking <- Set(true, false); autoReset <- Set(true, false)) { - test(s"instance reuse with autoReset = $autoReset, referenceTracking = $referenceTracking") { - testSerializerInstanceReuse(autoReset = autoReset, referenceTracking = referenceTracking) + for { + referenceTracking <- Seq(true, false) + autoReset <- Seq(true, false) + usePool <- Seq(true, false) + } { + test(s"instance reuse with autoReset = $autoReset, referenceTracking = $referenceTracking" + + s", usePool = $usePool") { + testSerializerInstanceReuse( + autoReset, referenceTracking, usePool) + } + } + + test("SPARK-25839 KryoPool implementation works correctly in multi-threaded environment") { + implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(4)) + + val ser = new KryoSerializer(conf.clone.set("spark.kryo.pool", "true")) + + val tests = mutable.ListBuffer[Future[Boolean]]() + + def check[T: ClassTag](t: T) { + tests += Future { + val serializerInstance = ser.newInstance() + serializerInstance.deserialize[T](serializerInstance.serialize(t)) === t + } + } + + check((1, 3)) + check(Array((1, 3))) + check(List((1, 3))) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(1 -> 1) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) + + tests.foreach { f => + assert(ThreadUtils.awaitResult(f, 10.seconds)) } } } From a3ba3a899b3b43958820dc82fcdd3a8b28653bcb Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 11 Nov 2018 14:05:19 +0800 Subject: [PATCH 020/145] [INFRA] Close stale PRs Closes https://github.com/apache/spark/pull/21766 Closes https://github.com/apache/spark/pull/21679 Closes https://github.com/apache/spark/pull/21161 Closes https://github.com/apache/spark/pull/20846 Closes https://github.com/apache/spark/pull/19434 Closes https://github.com/apache/spark/pull/18080 Closes https://github.com/apache/spark/pull/17648 Closes https://github.com/apache/spark/pull/17169 Add: Closes #22813 Closes #21994 Closes #22005 Closes #22463 Add: Closes #15899 Add: Closes #22539 Closes #21868 Closes #21514 Closes #21402 Closes #21322 Closes #21257 Closes #20163 Closes #19691 Closes #18697 Closes #18636 Closes #17176 Closes #23001 from wangyum/CloseStalePRs. Authored-by: Yuming Wang Signed-off-by: hyukjinkwon From aec0af4a952df2957e21d39d1e0546a36ab7ab86 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 11 Nov 2018 21:01:29 +0800 Subject: [PATCH 021/145] [SPARK-25972][PYTHON] Missed JSON options in streaming.py ## What changes were proposed in this pull request? 
Added JSON options for `json()` in streaming.py that are presented in the similar method in readwriter.py. In particular, missed options are `dropFieldIfAllNull` and `encoding`. Closes #22973 from MaxGekk/streaming-missed-options. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/streaming.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 02b14ea187cba..58ca7b83e5b2b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -404,7 +404,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None, + dropFieldIfAllNull=None, encoding=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -472,6 +473,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -486,7 +494,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale, + dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: From 510ec77a601db1c0fa338dd76a0ea7af63441fd3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 11 Nov 2018 09:21:40 -0600 Subject: [PATCH 022/145] [SPARK-19714][DOCS] Clarify Bucketizer handling of invalid input ## What changes were proposed in this pull request? Clarify Bucketizer handleInvalid docs. Just a resubmit of https://github.com/apache/spark/pull/17169 ## How was this patch tested? N/A Closes #23003 from srowen/SPARK-19714. 
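As an illustration of the behaviour the reworded docs describe (a minimal sketch, not part of the diff; the column name and split values below are invented):

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

object BucketizerHandleInvalidSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("bucketizer-sketch").getOrCreate()
    import spark.implicits._

    // Three buckets: [-1, 0), [0, 1), [1, 2]. NaN is the "invalid" input that handleInvalid governs.
    val df = Seq(-0.5, 0.3, 1.5, Double.NaN).toDF("feature")

    val bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("bucket")
      .setSplits(Array(-1.0, 0.0, 1.0, 2.0))
      .setHandleInvalid("keep") // NaN rows land in a special extra bucket (index 3 here)

    bucketizer.transform(df).show()

    // A finite value outside the splits, e.g. 5.0, would still fail at transform time
    // regardless of handleInvalid, which is exactly the point the clarified docs make.
    spark.stop()
  }
}
```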
Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 6 ++++-- python/pyspark/ml/feature.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f99649f7fa164..0b989b0d7d253 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -89,7 +89,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Param for how to handle invalid entries. Options are 'skip' (filter out rows with + * Param for how to handle invalid entries containing NaN values. Values outside the splits + * will always be treated as errors. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special * additional bucket). Note that in the multiple column case, the invalid handling is applied * to all columns. That said for 'error' it will throw an error if any invalids are found in @@ -99,7 +100,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String */ @Since("2.1.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "how to handle invalid entries. Options are skip (filter out rows with invalid values), " + + "how to handle invalid entries containing NaN values. Values outside the splits will always " + + "be treated as errorsOptions are skip (filter out rows with invalid values), " + "error (throw an error), or keep (keep invalid values in a special additional bucket).", ParamValidators.inArray(Bucketizer.supportedHandleInvalids)) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index eccb7acae5b98..3d23700242594 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -361,8 +361,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, "splits specified will be treated as errors.", typeConverter=TypeConverters.toListFloat) - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + - "Options are 'skip' (filter out rows with invalid values), " + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries " + "containing NaN values. Values outside the splits will always be treated " + "as errors. Options are 'skip' (filter out rows with invalid values), " + "'error' (throw an error), or 'keep' (keep invalid values in a special " + "additional bucket).", typeConverter=TypeConverters.toString) From d0ae48497c093cef23fb95c10aa448b3b498c758 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 12 Nov 2018 15:16:15 +0800 Subject: [PATCH 023/145] [SPARK-25949][SQL] Add test for PullOutPythonUDFInJoinCondition ## What changes were proposed in this pull request? As comment in https://github.com/apache/spark/pull/22326#issuecomment-424923967, we test the new added optimizer rule by end-to-end test in python side, need to add suites under `org.apache.spark.sql.catalyst.optimizer` like other optimizer rules. ## How was this patch tested? new added UT Closes #22955 from xuanyuanking/SPARK-25949. 
Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- ...PullOutPythonUDFInJoinConditionSuite.scala | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala new file mode 100644 index 0000000000000..d3867f2b6bd0e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.scalatest.Matchers._ + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.types.BooleanType + +class PullOutPythonUDFInJoinConditionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Extract PythonUDF From JoinCondition", Once, + PullOutPythonUDFInJoinCondition) :: + Batch("Check Cartesian Products", Once, + CheckCartesianProducts) :: Nil + } + + val testRelationLeft = LocalRelation('a.int, 'b.int) + val testRelationRight = LocalRelation('c.int, 'd.int) + + // Dummy python UDF for testing. Unable to execute. 
+ val pythonUDF = PythonUDF("pythonUDF", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + + val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti) + + private def comparePlanWithCrossJoinEnable(query: LogicalPlan, expected: LogicalPlan): Unit = { + // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false + val exception = intercept[AnalysisException] { + Optimize.execute(query.analyze) + } + assert(exception.message.startsWith("Detected implicit cartesian product")) + + // pull out the python udf while set spark.sql.crossJoin.enabled=true + withSQLConf(CROSS_JOINS_ENABLED.key -> "true") { + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + } + + test("inner join condition with python udf only") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("left semi join condition with python udf only") { + val query = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(pythonUDF)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF).select('a, 'b).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("python udf and common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF && 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("python udf or common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF || 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("pull out whole complex condition with multiple python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1 + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(condition).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("partial pull out complex condition with multiple python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("throw an exception for not 
support join type") { + for (joinType <- unsupportedJoinTypes) { + val thrownException = the [AnalysisException] thrownBy { + val query = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(pythonUDF)) + Optimize.execute(query.analyze) + } + assert(thrownException.message.contentEquals( + s"Using PythonUDF in join condition of join type $joinType is not supported.")) + } + } +} + From 0ba9715c7d1ef1eabc276320c81f0acb20bafb59 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 11 Nov 2018 23:21:47 -0800 Subject: [PATCH 024/145] [SPARK-26005][SQL] Upgrade ANTRL from 4.7 to 4.7.1 ## What changes were proposed in this pull request? Based on the release description of ANTRL 4.7.1., https://github.com/antlr/antlr4/releases, let us upgrade our parser to 4.7.1. ## How was this patch tested? N/A Closes #23005 from gatorsmile/upgradeAntlr4.7. Authored-by: gatorsmile Signed-off-by: gatorsmile --- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 15a570908cc9a..a3030bd601534 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -5,7 +5,7 @@ activation-1.1.1.jar aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.7.jar +antlr4-runtime-4.7.1.jar aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 6d9191a4abb4c..4354e76b521fc 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -7,7 +7,7 @@ activation-1.1.1.jar aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.7.jar +antlr4-runtime-4.7.1.jar aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar diff --git a/pom.xml b/pom.xml index a08b7fda33387..f58959b665e1b 100644 --- a/pom.xml +++ b/pom.xml @@ -174,7 +174,7 @@ 3.5.2 3.0.0 0.9.3 - 4.7 + 4.7.1 1.1 2.52.0 2.6 - 3.5 + 3.8.1 3.2.10 3.0.10 2.22.2 @@ -2016,7 +2016,7 @@ net.alchim31.maven scala-maven-plugin - 3.2.2 + 3.4.4 eclipse-add-source @@ -2281,7 +2281,19 @@ org.apache.maven.plugins maven-shade-plugin - 3.1.0 + 3.2.0 + + + org.ow2.asm + asm + 7.0 + + + org.ow2.asm + asm-commons + 7.0 + + org.apache.maven.plugins @@ -2296,7 +2308,7 @@ org.apache.maven.plugins maven-dependency-plugin - 3.0.2 + 3.1.1 default-cli diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 341a7fdbb59b8..a10245b372d71 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -19,7 +19,6 @@ package org.apache.hive.service.cli.thrift; import java.util.Arrays; -import java.util.concurrent.ExecutorService; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -65,7 +64,7 @@ public void run() { // Server thread pool // Start with minWorkerThreads, expand till maxWorkerThreads and reject subsequent requests String threadPoolName = "HiveServer2-HttpHandler-Pool"; - ExecutorService executorService = new ThreadPoolExecutor(minWorkerThreads, maxWorkerThreads, + ThreadPoolExecutor 
executorService = new ThreadPoolExecutor(minWorkerThreads, maxWorkerThreads, workerKeepAliveTime, TimeUnit.SECONDS, new SynchronousQueue(), new ThreadFactoryWithGarbageCleanup(threadPoolName)); ExecutorThreadPool threadPool = new ExecutorThreadPool(executorService); From 2b671e729250b980aa9e4ea2d483f44fa0e129cb Mon Sep 17 00:00:00 2001 From: gss2002 Date: Wed, 14 Nov 2018 13:02:13 -0800 Subject: [PATCH 042/145] =?UTF-8?q?[SPARK-25778]=20WriteAheadLogBackedBloc?= =?UTF-8?q?kRDD=20in=20YARN=20Cluster=20Mode=20Fails=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …due lack of access to tmpDir from $PWD to HDFS WriteAheadLogBackedBlockRDD usage of java.io.tmpdir will fail if $PWD resolves to a folder in HDFS and the Spark YARN Cluster job does not have the correct access to this folder in regards to the dummy folder. So this patch provides an option to set spark.streaming.receiver.blockStore.tmpdir to override java.io.tmpdir which sets $PWD from YARN Cluster mode. ## What changes were proposed in this pull request? This change provides an option to override the java.io.tmpdir option so that when $PWD is resolved in YARN Cluster mode Spark does not attempt to use this folder and instead use the folder provided with the following option: spark.streaming.receiver.blockStore.tmpdir ## How was this patch tested? Patch was manually tested on a Spark Streaming Job with Write Ahead logs in Cluster mode. Closes #22867 from gss2002/SPARK-25778. Authored-by: gss2002 Signed-off-by: Marcelo Vanzin --- .../spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 844760ab61d2e..f677c492d561f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -136,7 +136,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // this dummy directory should not already exist otherwise the WAL will try to recover // past events from the directory and throw errors. val nonExistentDirectory = new File( - System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).getAbsolutePath + System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).toURI.toString writeAheadLog = WriteAheadLogUtils.createLogForReceiver( SparkEnv.get.conf, nonExistentDirectory, hadoopConf) dataRead = writeAheadLog.read(partition.walRecordHandle) From 2977e2312d9690c9ced3c86b0ce937819e957775 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 14 Nov 2018 13:05:18 -0800 Subject: [PATCH 043/145] [SPARK-25986][BUILD] Add rules to ban throw Errors in application code ## What changes were proposed in this pull request? Add scala and java lint check rules to ban the usage of `throw new xxxErrors` and fix up all exists instance followed by https://github.com/apache/spark/pull/22989#issuecomment-437939830. See more details in https://github.com/apache/spark/pull/22969. ## How was this patch tested? Local test with lint-scala and lint-java. Closes #22989 from xuanyuanking/SPARK-25986. 
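To make the rule's intent concrete, a small self-contained ScalaTest sketch (illustrative only, not taken from the diff; the suite name and values are invented) of the pattern the ban pushes test code toward, together with the escape comments this patch itself introduces for deliberate `Error` throws:

```scala
import org.scalatest.FunSuite

class ThrowErrorLintRuleSketchSuite extends FunSuite {

  test("prefer fail() over throwing JVM Errors in assertions") {
    val parsed: Option[Int] = Some(42)
    parsed match {
      case Some(v) => assert(v === 42)
      // Branches like this sometimes threw new AssertionError(...);
      // the new scalastyle rule flags that, and ScalaTest's fail(...) replaces it.
      case None => fail("expected a parsed value but got None")
    }
  }

  test("deliberate Error throws still compile with the escape comments") {
    intercept[LinkageError] {
      // scalastyle:off throwerror
      throw new LinkageError("thrown on purpose for this test")
      // scalastyle:on throwerror
    }
  }
}
```

Production code follows the same idea with exceptions such as UnsupportedOperationException or IllegalArgumentException in place of fail(), as the bulk of this diff shows.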
Authored-by: Yuanjian Li Signed-off-by: Sean Owen --- .../spark/unsafe/UnsafeAlignedOffset.java | 4 +++ .../apache/spark/memory/MemoryConsumer.java | 2 ++ .../spark/memory/TaskMemoryManager.java | 4 +++ .../unsafe/sort/UnsafeInMemorySorter.java | 2 ++ .../spark/util/random/RandomSampler.scala | 2 +- .../scala/org/apache/spark/FailureSuite.scala | 2 ++ .../apache/spark/executor/ExecutorSuite.scala | 2 ++ .../scheduler/TaskResultGetterSuite.scala | 2 ++ .../spark/storage/BlockManagerSuite.scala | 2 +- dev/checkstyle.xml | 13 +++++--- .../spark/streaming/kafka010/KafkaUtils.scala | 2 +- .../org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 8 ++--- .../org/apache/spark/ml/param/params.scala | 4 +-- .../spark/ml/tuning/ValidatorParams.scala | 4 +-- .../mllib/classification/NaiveBayes.scala | 2 +- .../apache/spark/mllib/linalg/Vectors.scala | 2 +- .../org/apache/spark/ml/PredictorSuite.scala | 6 ++-- .../ml/classification/ClassifierSuite.scala | 11 ++++--- .../ml/classification/NaiveBayesSuite.scala | 4 +-- .../ml/classification/OneVsRestSuite.scala | 16 +++++----- .../spark/ml/feature/VectorIndexerSuite.scala | 4 ++- .../ml/tree/impl/RandomForestSuite.scala | 6 ++-- .../apache/spark/ml/tree/impl/TreeTests.scala | 6 ++-- .../spark/ml/tuning/CrossValidatorSuite.scala | 32 +++++++++---------- .../ml/tuning/TrainValidationSplitSuite.scala | 12 +++---- .../tuning/ValidatorParamsSuiteHelpers.scala | 3 +- .../classification/NaiveBayesSuite.scala | 2 +- .../spark/mllib/clustering/KMeansSuite.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 15 +++++---- scalastyle-config.xml | 11 +++++++ .../TungstenAggregationIterator.scala | 2 ++ .../spark/sql/FileBasedDataSourceSuite.scala | 9 +++--- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 2 +- .../spark/streaming/util/StateMap.scala | 2 +- .../spark/streaming/InputStreamsSuite.scala | 5 +-- .../spark/streaming/StateMapSuite.scala | 2 +- 39 files changed, 128 insertions(+), 87 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java index be62e40412f83..546e8780a6606 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java @@ -39,7 +39,9 @@ public static int getSize(Object object, long offset) { case 8: return (int)Platform.getLong(object, offset); default: + // checkstyle.off: RegexpSinglelineJava throw new AssertionError("Illegal UAO_SIZE"); + // checkstyle.on: RegexpSinglelineJava } } @@ -52,7 +54,9 @@ public static void putSize(Object object, long offset, int value) { Platform.putLong(object, offset, value); break; default: + // checkstyle.off: RegexpSinglelineJava throw new AssertionError("Illegal UAO_SIZE"); + // checkstyle.on: RegexpSinglelineJava } } } diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 115e1fbb79a2e..8371deca7311d 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -154,7 +154,9 @@ private void throwOom(final MemoryBlock page, final long required) { taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); + // checkstyle.off: 
RegexpSinglelineJava throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d07faf1da1248..28b646ba3c951 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -194,8 +194,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + c, e); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava } } } @@ -215,8 +217,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 75690ae264838..1a9453a8b3e80 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -214,7 +214,9 @@ public boolean hasSpaceForAnotherRecord() { public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + // checkstyle.on: RegexpSinglelineJava } Platform.copyMemory( array.getBaseObject(), diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index ea99a7e5b4847..70554f1d03067 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -49,7 +49,7 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable /** return a copy of the RandomSampler object */ override def clone: RandomSampler[T, U] = - throw new NotImplementedError("clone() is not implemented.") + throw new UnsupportedOperationException("clone() is not implemented.") } private[spark] diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index d805c67714ff8..f2d97d452ddb0 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -257,7 +257,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local[1,2]", "test") intercept[SparkException] { sc.parallelize(1 to 2).foreach { i => + // scalastyle:off throwerror throw new LinkageError() + // scalastyle:on throwerror } } } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 
1f8a65707b2f7..32a94e60484e3 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -467,7 +467,9 @@ class FetchFailureHidingRDD( } catch { case t: Throwable => if (throwOOM) { + // scalastyle:off throwerror throw new OutOfMemoryError("OOM while handling another exception") + // scalastyle:on throwerror } else if (interrupt) { // make sure our test is setup correctly assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index f8eb8bd71c170..efb8b15cf6b4d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -265,7 +265,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local private class UndeserializableException extends Exception { private def readObject(in: ObjectInputStream): Unit = { + // scalastyle:off throwerror throw new NoClassDefFoundError() + // scalastyle:on throwerror } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 32d6e8b94e1a2..cf00c1c3aad39 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -574,7 +574,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE "list1", StorageLevel.MEMORY_ONLY, ClassTag.Any, - () => throw new AssertionError("attempted to compute locally")).isLeft) + () => fail("attempted to compute locally")).isLeft) } test("in-memory LRU storage") { diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 53c284888ebb0..e8859c01f2bd8 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -71,13 +71,13 @@ If you wish to turn off checking for a section of code, you can put a comment in the source before and after the section, with the following syntax: - // checkstyle:off no.XXX (such as checkstyle.off: NoFinalizer) + // checkstyle.off: XXX (such as checkstyle.off: NoFinalizer) ... 
// stuff that breaks the styles - // checkstyle:on + // checkstyle.on: XXX (such as checkstyle.on: NoFinalizer) --> - - + + @@ -180,5 +180,10 @@ + + + + + diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala index 64b6ef6c53b6d..2516b948f6650 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -56,7 +56,7 @@ object KafkaUtils extends Logging { ): RDD[ConsumerRecord[K, V]] = { val preferredHosts = locationStrategy match { case PreferBrokers => - throw new AssertionError( + throw new IllegalArgumentException( "If you want to prefer brokers, you must provide a mapping using PreferFixed " + "A single KafkaRDD does not have a driver consumer and cannot look up brokers for you.") case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]() diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 5824e463ca1aa..6e950f968a65d 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -106,7 +106,7 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def copy: Vector = { - throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.") } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 51495c1a74e69..1a7a5e7a52344 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -146,7 +146,7 @@ class NaiveBayes @Since("1.5.0") ( requireZeroOneBernoulliValues case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } } @@ -196,7 +196,7 @@ class NaiveBayes @Since("1.5.0") ( case Bernoulli => math.log(n + 2.0 * lambda) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } var j = 0 while (j < numFeatures) { @@ -295,7 +295,7 @@ class NaiveBayesModel private[ml] ( (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } @Since("1.6.0") @@ -329,7 +329,7 @@ class NaiveBayesModel private[ml] ( bernoulliCalculation(features) case _ => // This should never happen. 
- throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index e6c347ed17c15..4c50f1e3292bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -97,7 +97,7 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali case m: Matrix => JsonMatrixConverter.toJson(m) case _ => - throw new NotImplementedError( + throw new UnsupportedOperationException( "The default jsonEncode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } @@ -151,7 +151,7 @@ private[ml] object Param { } case _ => - throw new NotImplementedError( + throw new UnsupportedOperationException( "The default jsonDecode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonDecode to support its value type.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 135828815504a..6d46ea0adcc9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -140,8 +140,8 @@ private[ml] object ValidatorParams { "value" -> compact(render(JString(relativePath))), "isJson" -> compact(render(JBool(false)))) case _: MLWritable => - throw new NotImplementedError("ValidatorParams.saveImpl does not handle parameters " + - "of type: MLWritable that are not DefaultParamsWritable") + throw new UnsupportedOperationException("ValidatorParams.saveImpl does not handle" + + " parameters of type: MLWritable that are not DefaultParamsWritable") case _ => Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v), "isJson" -> compact(render(JBool(true)))) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 9e8774732efe6..16ba6cabdc823 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -83,7 +83,7 @@ class NaiveBayesModel private[spark] ( (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. 
- throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } @Since("1.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6e68d9684a672..9cdf1944329b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -117,7 +117,7 @@ sealed trait Vector extends Serializable { */ @Since("1.1.0") def copy: Vector = { - throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala index ec45e32d412a9..dff00eade620f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -73,7 +73,7 @@ object PredictorSuite { } override def copy(extra: ParamMap): MockPredictor = - throw new NotImplementedError() + throw new UnsupportedOperationException() } class MockPredictionModel(override val uid: String) @@ -82,9 +82,9 @@ object PredictorSuite { def this() = this(Identifiable.randomUID("mockpredictormodel")) override def predict(features: Vector): Double = - throw new NotImplementedError() + throw new UnsupportedOperationException() override def copy(extra: ParamMap): MockPredictionModel = - throw new NotImplementedError() + throw new UnsupportedOperationException() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 87bf2be06c2be..be52d99e54d3b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -117,10 +117,10 @@ object ClassifierSuite { def this() = this(Identifiable.randomUID("mockclassifier")) - override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError() + override def copy(extra: ParamMap): MockClassifier = throw new UnsupportedOperationException() override def train(dataset: Dataset[_]): MockClassificationModel = - throw new NotImplementedError() + throw new UnsupportedOperationException() // Make methods public override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = @@ -133,11 +133,12 @@ object ClassifierSuite { def this() = this(Identifiable.randomUID("mockclassificationmodel")) - protected def predictRaw(features: Vector): Vector = throw new NotImplementedError() + protected def predictRaw(features: Vector): Vector = throw new UnsupportedOperationException() - override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError() + override def copy(extra: ParamMap): MockClassificationModel = + throw new UnsupportedOperationException() - override def numClasses: Int = throw new NotImplementedError() + override def numClasses: Int = throw new UnsupportedOperationException() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 5f9ab98a2c3ce..a8c4f091b2aed 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -103,7 +103,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { case Bernoulli => expectedBernoulliProbabilities(model, features) case _ => - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } assert(probability ~== expected relTol 1.0e-10) } @@ -378,7 +378,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 2c3417c7e4028..519ec1720eb98 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -134,8 +134,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lrModel1.coefficients ~== lrModel2.coefficients relTol 1E-3) assert(lrModel1.intercept ~== lrModel2.intercept relTol 1E-3) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail("Loaded OneVsRestModel expected model of type LogisticRegressionModel " + + s"but found ${other.getClass.getName}") } } @@ -247,8 +247,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lr.getMaxIter === lr2.getMaxIter) assert(lr.getRegParam === lr2.getRegParam) case other => - throw new AssertionError(s"Loaded OneVsRest expected classifier of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded OneVsRest expected classifier of type LogisticRegression" + + s" but found ${other.getClass.getName}") } } @@ -267,8 +267,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(classifier.getMaxIter === lr2.getMaxIter) assert(classifier.getRegParam === lr2.getRegParam) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded OneVsRestModel expected classifier of type LogisticRegression" + + s" but found ${other.getClass.getName}") } assert(model.labelMetadata === model2.labelMetadata) @@ -278,8 +278,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lrModel1.coefficients === lrModel2.coefficients) assert(lrModel1.intercept === lrModel2.intercept) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail(s"Loaded OneVsRestModel expected model of type LogisticRegressionModel" + + s" but found ${other.getClass.getName}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index e5675e31bbecf..fb5789f945dec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -283,7 +283,9 @@ class VectorIndexerSuite extends MLTest with 
DefaultReadWriteTest with Logging { points.zip(rows.map(_(0))).foreach { case (orig: SparseVector, indexed: SparseVector) => assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + case _ => + // should never happen + fail("Unit test has a bug in it.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 743dacf146fe7..5caa5117d5752 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -417,9 +417,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { case n: InternalNode => n.split match { case s: CategoricalSplit => assert(s.leftCategories === Array(1.0)) - case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") + case _ => fail("model.rootNode.split was not a CategoricalSplit") } - case _ => throw new AssertionError("model.rootNode was not an InternalNode") + case _ => fail("model.rootNode was not an InternalNode") } } @@ -444,7 +444,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(n.leftChild.isInstanceOf[InternalNode]) assert(n.rightChild.isInstanceOf[InternalNode]) Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) - case _ => throw new AssertionError("rootNode was not an InternalNode") + case _ => fail("rootNode was not an InternalNode") } // Single group second level tree construction. diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..ae9794b87b08d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -112,7 +112,7 @@ private[ml] object TreeTests extends SparkFunSuite { checkEqual(a.rootNode, b.rootNode) } catch { case ex: Exception => - throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + fail("checkEqual failed since the two trees were not identical.\n" + "TREE A:\n" + a.toDebugString + "\n" + "TREE B:\n" + b.toDebugString + "\n", ex) } @@ -133,7 +133,7 @@ private[ml] object TreeTests extends SparkFunSuite { checkEqual(aye.rightChild, bee.rightChild) case (aye: LeafNode, bee: LeafNode) => // do nothing case _ => - throw new AssertionError("Found mismatched nodes") + fail("Found mismatched nodes") } } @@ -148,7 +148,7 @@ private[ml] object TreeTests extends SparkFunSuite { } assert(a.treeWeights === b.treeWeights) } catch { - case ex: Exception => throw new AssertionError( + case ex: Exception => fail( "checkEqual failed since the two tree ensembles were not identical") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index e6ee7220d2279..a30428ec2d283 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -190,8 +190,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found 
${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -281,13 +281,13 @@ class CrossValidatorSuite assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" OneVsRest but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type OneVsRest but " + + s"found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -364,8 +364,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded internal CrossValidator expected to be" + - s" LogisticRegression but found type ${other.getClass.getName}") + fail("Loaded internal CrossValidator expected to be LogisticRegression" + + s" but found type ${other.getClass.getName}") } assert(lrcv.uid === lrcv2.uid) assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) @@ -373,12 +373,12 @@ class CrossValidatorSuite ValidatorParamsSuiteHelpers .compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) case other => - throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + - " but found: " + other.map(_.getClass.getName).mkString(", ")) + fail("Loaded Pipeline expected stages (HashingTF, CrossValidator) but found: " + + other.map(_.getClass.getName).mkString(", ")) } case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" CrossValidator but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type CrossValidator but found" + + s" ${other.getClass.getName}") } } @@ -433,8 +433,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getThreshold === lr2.getThreshold) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -447,8 +447,8 @@ class CrossValidatorSuite assert(lrModel.coefficients === lrModel2.coefficients) assert(lrModel.intercept === lrModel2.intercept) case other => - throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected bestModel of type LogisticRegressionModel" + + s" but found ${other.getClass.getName}") } assert(cv.avgMetrics === cv2.avgMetrics) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index cd76acf9c67bc..289db336eca5d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -187,8 +187,8 @@ class TrainValidationSplitSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new 
AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded TrainValidationSplit expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } } @@ -264,13 +264,13 @@ class TrainValidationSplitSuite assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail(s"Loaded TrainValidationSplit expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" OneVsRest but found ${other.getClass.getName}") + fail(s"Loaded TrainValidationSplit expected estimator of type OneVsRest" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala index eae1f5adc8842..cea2f50d3470c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala @@ -47,8 +47,7 @@ object ValidatorParamsSuiteHelpers extends Assertions { val estimatorParamMap2 = Array(estimator2.extractParamMap()) compareParamMaps(estimatorParamMap, estimatorParamMap2) case other => - throw new AssertionError(s"Expected parameter of type Params but" + - s" found ${otherParam.getClass.getName}") + fail(s"Expected parameter of type Params but found ${otherParam.getClass.getName}") } case _ => assert(otherParam === v) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 5ec4c15387e94..8c7d583923b32 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -71,7 +71,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. 
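(For illustration only — a minimal sketch, not part of this diff, of the pattern these test suites converge on: mixing in `org.scalatest.Assertions` and reporting unexpected branches through `fail(...)` instead of throwing a `java.lang.*Error`, which the new scalastyle rule added below forbids. The helper name here is hypothetical.)

```scala
// Sketch only (hypothetical helper, not from this patch): report unexpected
// branches via ScalaTest's fail(...) rather than throwing an Error subclass.
import org.scalatest.Assertions

object FailPatternSketch extends Assertions {
  def checkModelType(modelType: String): Unit = modelType match {
    case "multinomial" | "bernoulli" => // expected values: nothing to check here
    case other => fail(s"Invalid modelType: $other.")
  }
}
```
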
- throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 1b98250061c7a..d18cef7e264db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -349,7 +349,7 @@ object KMeansSuite extends SparkFunSuite { case (ca: DenseVector, cb: DenseVector) => assert(ca === cb) case _ => - throw new AssertionError("checkEqual failed since the two clusters were not identical.\n") + fail("checkEqual failed since the two clusters were not identical.\n") } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bc59f3f4125fb..34bc303ac6079 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -607,7 +607,7 @@ object DecisionTreeSuite extends SparkFunSuite { checkEqual(a.topNode, b.topNode) } catch { case ex: Exception => - throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + fail("checkEqual failed since the two trees were not identical.\n" + "TREE A:\n" + a.toDebugString + "\n" + "TREE B:\n" + b.toDebugString + "\n", ex) } @@ -628,20 +628,21 @@ object DecisionTreeSuite extends SparkFunSuite { // TODO: Check other fields besides the information gain. case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) case (None, None) => - case _ => throw new AssertionError( - s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})") + case _ => fail(s"Only one instance has stats defined. (a.stats: ${a.stats}, " + + s"b.stats: ${b.stats})") } (a.leftNode, b.leftNode) match { case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode) case (None, None) => - case _ => throw new AssertionError("Only one instance has leftNode defined. " + - s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})") + case _ => + fail("Only one instance has leftNode defined. (a.leftNode: ${a.leftNode}," + + " b.leftNode: ${b.leftNode})") } (a.rightNode, b.rightNode) match { case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode) case (None, None) => - case _ => throw new AssertionError("Only one instance has rightNode defined. " + - s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})") + case _ => fail("Only one instance has rightNode defined. 
(a.rightNode: ${a.rightNode}, " + + "b.rightNode: ${b.rightNode})") } } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 36a73e3362218..4892819ae9973 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -240,6 +240,17 @@ This file is divided into 3 sections: ]]> + + throw new \w+Error\( + + + JavaConversions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 72505f7fac0c6..6d849869b577a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -206,7 +206,9 @@ class TungstenAggregationIterator( buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // failed to allocate the first page + // scalastyle:off throwerror throw new SparkOutOfMemoryError("No enough memory for aggregation") + // scalastyle:on throwerror } } processRow(buffer, newInput) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 94f163708832c..64b42c32b8b1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -509,9 +509,9 @@ object TestingUDT { override def sqlType: DataType = CalendarIntervalType override def serialize(obj: IntervalData): Any = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def deserialize(datum: Any): IntervalData = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def userClass: Class[IntervalData] = classOf[IntervalData] } @@ -521,9 +521,10 @@ object TestingUDT { private[sql] class NullUDT extends UserDefinedType[NullData] { override def sqlType: DataType = NullType - override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def serialize(obj: NullData): Any = + throw new UnsupportedOperationException("Not implemented") override def deserialize(datum: Any): NullData = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def userClass: Class[NullData] = classOf[NullData] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index e4e224df7607f..142ab6170a734 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -790,6 +790,6 @@ private case class DummySparkPlan( override val requiredChildDistribution: Seq[Distribution] = Nil, override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil ) extends SparkPlan { - override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index bceaf1a9ec061..955c3e3fa6f74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -614,7 +614,7 @@ class TestFileFormat extends TextBasedFileFormat { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - throw new NotImplementedError("JUST FOR TESTING") + throw new UnsupportedOperationException("JUST FOR TESTING") } override def buildReader( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index f57f07b498261..e8062dbb91e35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1123,7 +1123,7 @@ class ColumnarBatchSuite extends SparkFunSuite { compareStruct(childFields, r1.getStruct(ordinal, fields.length), r2.getStruct(ordinal), seed) case _ => - throw new NotImplementedError("Not implemented " + field.dataType) + throw new UnsupportedOperationException("Not implemented " + field.dataType) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 89524cd84ff32..618c036377aee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -70,7 +70,7 @@ private[streaming] object StateMap { /** Implementation of StateMap interface representing an empty map */ private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { - throw new NotImplementedError("put() should not be called on an EmptyStateMap") + throw new UnsupportedOperationException("put() should not be called on an EmptyStateMap") } override def get(key: K): Option[S] = None override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 1cf21e8a28033..7376741f64a12 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -31,6 +31,7 @@ import org.apache.commons.io.IOUtils import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.scalatest.Assertions import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ @@ -532,7 +533,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { /** This is a server to test the network input stream */ -class TestServer(portToBind: Int = 0) extends Logging { +class TestServer(portToBind: Int = 0) extends Logging with Assertions { val queue = new ArrayBlockingQueue[String](100) @@ -592,7 +593,7 @@ class TestServer(portToBind: Int = 0) extends Logging { servingThread.start() if (!waitForStart(10000)) { stop() - throw new AssertionError("Timeout: TestServer cannot start in 
10 seconds") + fail("Timeout: TestServer cannot start in 10 seconds") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 484f3733e8423..e444132d3a626 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -35,7 +35,7 @@ class StateMapSuite extends SparkFunSuite { test("EmptyStateMap") { val map = new EmptyStateMap[Int, Int] - intercept[scala.NotImplementedError] { + intercept[UnsupportedOperationException] { map.put(1, 1, 1) } assert(map.get(1) === None) From ad853c56788fd32e035369d1fe3d96aaf6c4ef16 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 14 Nov 2018 16:22:23 -0800 Subject: [PATCH 044/145] [SPARK-25956] Make Scala 2.12 as default Scala version in Spark 3.0 ## What changes were proposed in this pull request? This PR makes Spark's default Scala version as 2.12, and Scala 2.11 will be the alternative version. This implies that Scala 2.12 will be used by our CI builds including pull request builds. We'll update the Jenkins to include a new compile-only jobs for Scala 2.11 to ensure the code can be still compiled with Scala 2.11. ## How was this patch tested? existing tests Closes #22967 from dbtsai/scala2.12. Authored-by: DB Tsai Signed-off-by: Dongjoon Hyun --- assembly/pom.xml | 4 +-- common/kvstore/pom.xml | 4 +-- common/network-common/pom.xml | 4 +-- common/network-shuffle/pom.xml | 4 +-- common/network-yarn/pom.xml | 4 +-- common/sketch/pom.xml | 4 +-- common/tags/pom.xml | 4 +-- common/unsafe/pom.xml | 4 +-- core/pom.xml | 4 +-- dev/deps/spark-deps-hadoop-2.7 | 36 +++++++++---------- dev/deps/spark-deps-hadoop-3.1 | 36 +++++++++---------- docs/_config.yml | 4 +-- docs/_plugins/copy_api_dirs.rb | 2 +- docs/building-spark.md | 18 +++++----- docs/cloud-integration.md | 2 +- docs/sparkr.md | 2 +- examples/pom.xml | 4 +-- external/avro/pom.xml | 4 +-- external/docker-integration-tests/pom.xml | 4 +-- external/kafka-0-10-assembly/pom.xml | 4 +-- external/kafka-0-10-sql/pom.xml | 4 +-- external/kafka-0-10/pom.xml | 4 +-- external/kinesis-asl-assembly/pom.xml | 4 +-- external/kinesis-asl/pom.xml | 4 +-- external/spark-ganglia-lgpl/pom.xml | 4 +-- graphx/pom.xml | 4 +-- hadoop-cloud/pom.xml | 4 +-- launcher/pom.xml | 4 +-- mllib-local/pom.xml | 4 +-- mllib/pom.xml | 4 +-- pom.xml | 20 ++++++----- project/MimaBuild.scala | 2 +- project/SparkBuild.scala | 14 ++++---- python/run-tests.py | 4 +-- repl/pom.xml | 4 +-- resource-managers/kubernetes/core/pom.xml | 4 +-- .../kubernetes/integration-tests/pom.xml | 4 +-- resource-managers/mesos/pom.xml | 4 +-- resource-managers/yarn/pom.xml | 4 +-- sql/catalyst/pom.xml | 4 +-- sql/core/pom.xml | 4 +-- sql/hive-thriftserver/pom.xml | 4 +-- sql/hive/pom.xml | 4 +-- streaming/pom.xml | 4 +-- tools/pom.xml | 4 +-- 45 files changed, 138 insertions(+), 138 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index b0337e58cca71..68ebfadb668ab 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-assembly_2.11 + spark-assembly_2.12 Spark Project Assembly http://spark.apache.org/ pom diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 23a0f49206909..f042a12fda3d2 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + 
spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-kvstore_2.11 + spark-kvstore_2.12 jar Spark Project Local DB http://spark.apache.org/ diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 41fcbf0589499..56d01fa0e8b3d 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-common_2.11 + spark-network-common_2.12 jar Spark Project Networking http://spark.apache.org/ diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index ff717057bb25d..a6d99813a8501 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-shuffle_2.11 + spark-network-shuffle_2.12 jar Spark Project Shuffle Streaming Service http://spark.apache.org/ diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a1cf761d12d8b..55cdc3140aa08 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-yarn_2.11 + spark-network-yarn_2.12 jar Spark Project YARN Shuffle Service http://spark.apache.org/ diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index adbbcb1cb3040..3c3c0d2d96a1c 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-sketch_2.11 + spark-sketch_2.12 jar Spark Project Sketch http://spark.apache.org/ diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f6627beabe84b..883b73a69c9de 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-tags_2.11 + spark-tags_2.12 jar Spark Project Tags http://spark.apache.org/ diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 62c493a5e1ed8..7e4b08217f1b0 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-unsafe_2.11 + spark-unsafe_2.12 jar Spark Project Unsafe http://spark.apache.org/ diff --git a/core/pom.xml b/core/pom.xml index 5c26f9a5ea3c6..36d93212ba9f9 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-core_2.11 + spark-core_2.12 core diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 01691811fd3eb..c2f5755ca9925 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -22,13 +22,13 @@ avro-1.8.2.jar avro-ipc-1.8.2.jar avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.13.2.jar -breeze_2.11-0.13.2.jar +breeze-macros_2.12-0.13.2.jar +breeze_2.12-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.9.3.jar -chill_2.11-0.9.3.jar +chill_2.12-0.9.3.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -96,7 +96,7 @@ jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.9.6.jar jackson-module-paranamer-2.9.6.jar 
-jackson-module-scala_2.11-2.9.6.jar +jackson-module-scala_2.12-2.9.6.jar jackson-xc-1.9.13.jar janino-3.0.10.jar javassist-3.18.1-GA.jar @@ -122,10 +122,10 @@ jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.5.3.jar -json4s-core_2.11-3.5.3.jar -json4s-jackson_2.11-3.5.3.jar -json4s-scalap_2.11-3.5.3.jar +json4s-ast_2.12-3.5.3.jar +json4s-core_2.12-3.5.3.jar +json4s-jackson_2.12-3.5.3.jar +json4s-scalap_2.12-3.5.3.jar jsp-api-2.1.jar jsr305-3.0.0.jar jta-1.1.jar @@ -140,8 +140,8 @@ libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.9.1.jar lz4-java-1.5.0.jar -machinist_2.11-0.6.1.jar -macro-compat_2.11-1.1.1.jar +machinist_2.12-0.6.1.jar +macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar @@ -170,19 +170,19 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar -scala-compiler-2.11.12.jar -scala-library-2.11.12.jar -scala-parser-combinators_2.11-1.1.0.jar -scala-reflect-2.11.12.jar -scala-xml_2.11-1.0.5.jar -shapeless_2.11-2.3.2.jar +scala-compiler-2.12.7.jar +scala-library-2.12.7.jar +scala-parser-combinators_2.12-1.1.0.jar +scala-reflect-2.12.7.jar +scala-xml_2.12-1.0.5.jar +shapeless_2.12-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.18.jar snappy-0.2.jar snappy-java-1.1.7.1.jar -spire-macros_2.11-0.13.0.jar -spire_2.11-0.13.0.jar +spire-macros_2.12-0.13.0.jar +spire_2.12-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index fd46f1491874a..811febf22940d 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -20,13 +20,13 @@ avro-1.8.2.jar avro-ipc-1.8.2.jar avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.13.2.jar -breeze_2.11-0.13.2.jar +breeze-macros_2.12-0.13.2.jar +breeze_2.12-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.9.3.jar -chill_2.11-0.9.3.jar +chill_2.12-0.9.3.jar commons-beanutils-1.9.3.jar commons-cli-1.2.jar commons-codec-1.10.jar @@ -96,7 +96,7 @@ jackson-jaxrs-json-provider-2.7.8.jar jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.9.6.jar jackson-module-paranamer-2.9.6.jar -jackson-module-scala_2.11-2.9.6.jar +jackson-module-scala_2.12-2.9.6.jar janino-3.0.10.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar @@ -123,10 +123,10 @@ joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar json-smart-2.3.jar -json4s-ast_2.11-3.5.3.jar -json4s-core_2.11-3.5.3.jar -json4s-jackson_2.11-3.5.3.jar -json4s-scalap_2.11-3.5.3.jar +json4s-ast_2.12-3.5.3.jar +json4s-core_2.12-3.5.3.jar +json4s-jackson_2.12-3.5.3.jar +json4s-scalap_2.12-3.5.3.jar jsp-api-2.1.jar jsr305-3.0.0.jar jta-1.1.jar @@ -155,8 +155,8 @@ libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.9.1.jar lz4-java-1.5.0.jar -machinist_2.11-0.6.1.jar -macro-compat_2.11-1.1.1.jar +machinist_2.12-0.6.1.jar +macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar @@ -189,19 +189,19 @@ protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar re2j-1.1.jar -scala-compiler-2.11.12.jar -scala-library-2.11.12.jar -scala-parser-combinators_2.11-1.1.0.jar -scala-reflect-2.11.12.jar -scala-xml_2.11-1.0.5.jar -shapeless_2.11-2.3.2.jar +scala-compiler-2.12.7.jar +scala-library-2.12.7.jar +scala-parser-combinators_2.12-1.1.0.jar +scala-reflect-2.12.7.jar 
+scala-xml_2.12-1.0.5.jar +shapeless_2.12-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.18.jar snappy-0.2.jar snappy-java-1.1.7.1.jar -spire-macros_2.11-0.13.0.jar -spire_2.11-0.13.0.jar +spire-macros_2.12-0.13.0.jar +spire_2.12-0.13.0.jar stax-api-1.0.1.jar stax2-api-3.1.4.jar stream-2.7.0.jar diff --git a/docs/_config.yml b/docs/_config.yml index c3ef98575fa62..649d18bf72b57 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -16,8 +16,8 @@ include: # of Spark, Scala, and Mesos. SPARK_VERSION: 3.0.0-SNAPSHOT SPARK_VERSION_SHORT: 3.0.0 -SCALA_BINARY_VERSION: "2.11" -SCALA_VERSION: "2.11.12" +SCALA_BINARY_VERSION: "2.12" +SCALA_VERSION: "2.12.7" MESOS_VERSION: 1.0.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 4d0d043a349bb..2d1a9547e3731 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.11/unidoc" + source = "../target/scala-2.12/unidoc" dest = "api/scala" puts "Making directory " + dest diff --git a/docs/building-spark.md b/docs/building-spark.md index 8af90db9a19dd..dfcd53c48e85c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -96,9 +96,9 @@ It's possible to build Spark submodules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: - ./build/mvn -pl :spark-streaming_2.11 clean install + ./build/mvn -pl :spark-streaming_{{site.SCALA_BINARY_VERSION}} clean install -where `spark-streaming_2.11` is the `artifactId` as defined in `streaming/pom.xml` file. +where `spark-streaming_{{site.SCALA_BINARY_VERSION}}` is the `artifactId` as defined in `streaming/pom.xml` file. ## Continuous Compilation @@ -230,7 +230,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. ./build/mvn install -DskipTests - ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_{{site.SCALA_BINARY_VERSION}} or @@ -238,17 +238,17 @@ or ## Change Scala Version -To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.12): +To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.11): - ./dev/change-scala-version.sh 2.12 + ./dev/change-scala-version.sh 2.11 -For Maven, please enable the profile (e.g. 2.12): +For Maven, please enable the profile (e.g. 2.11): - ./build/mvn -Pscala-2.12 compile + ./build/mvn -Pscala-2.11 compile -For SBT, specify a complete scala version using (e.g. 2.12.6): +For SBT, specify a complete scala version using (e.g. 2.11.12): - ./build/sbt -Dscala.version=2.12.6 + ./build/sbt -Dscala.version=2.11.12 Otherwise, the sbt-pom-reader plugin will use the `scala.version` specified in the spark-parent pom. diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 36753f6373b55..5368e13727334 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -85,7 +85,7 @@ is set to the chosen version of Spark: ... org.apache.spark - hadoop-cloud_2.11 + hadoop-cloud_{{site.SCALA_BINARY_VERSION}} ${spark.version} ... 
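Downstream builds pick up this change mainly through the artifact suffix; a minimal sbt sketch (a hypothetical project outside this repository, shown only to illustrate how `%%` now resolves to the `_2.12` artifacts):

```scala
// build.sbt — hypothetical downstream project, not part of this patch.
scalaVersion := "2.12.7"

// "%%" appends the Scala binary version, so this resolves to spark-sql_2.12.
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.0-SNAPSHOT" % "provided"
```
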
diff --git a/docs/sparkr.md b/docs/sparkr.md index cc6bc6d14853d..acd0e77c4d71a 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -133,7 +133,7 @@ specifying `--packages` with `spark-submit` or `sparkR` commands, or if initiali
    {% highlight r %} -sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") +sparkR.session(sparkPackages = "org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION}}") {% endhighlight %}
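The same naming convention applies to package coordinates resolved at runtime; a hedged Scala counterpart of the SparkR snippet above (hypothetical application code, not part of this patch — the object name and version string are placeholders):

```scala
// Hypothetical app: --packages style coordinates embed the Scala binary version,
// so they move from spark-avro_2.11 to spark-avro_2.12 as well.
import org.apache.spark.sql.SparkSession

object AvroPackageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("avro-package-sketch")
      .config("spark.jars.packages", "org.apache.spark:spark-avro_2.12:3.0.0-SNAPSHOT")
      .getOrCreate()
    spark.stop()
  }
}
```
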
    diff --git a/examples/pom.xml b/examples/pom.xml index 756c475b4748d..0636406595f6e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-examples_2.11 + spark-examples_2.12 jar Spark Project Examples http://spark.apache.org/ diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 9d8f319cc9396..ba6f20bfdbf58 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-avro_2.11 + spark-avro_2.12 avro diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index f24254b698080..b39db7540b7d2 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-docker-integration-tests_2.11 + spark-docker-integration-tests_2.12 jar Spark Project Docker Integration Tests http://spark.apache.org/ diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 4f9c3163b2408..f2dcf5d217a89 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10-assembly_2.11 + spark-streaming-kafka-0-10-assembly_2.12 jar Spark Integration for Kafka 0.10 Assembly http://spark.apache.org/ diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index efd0862fb58ee..3f1055a75076f 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql-kafka-0-10_2.11 + spark-sql-kafka-0-10_2.12 sql-kafka-0-10 diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index f59f07265a0f4..d75b13da8fb70 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10_2.11 + spark-streaming-kafka-0-10_2.12 streaming-kafka-0-10 diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 0bf4c265939e7..0ce922349ea66 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl-assembly_2.11 + spark-streaming-kinesis-asl-assembly_2.12 jar Spark Project Kinesis Assembly http://spark.apache.org/ diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 0aef25329db99..7d69764b77de7 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl_2.11 + spark-streaming-kinesis-asl_2.12 jar Spark Kinesis Integration diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 35a55b70baf33..a23d255f9187c 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - 
spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-ganglia-lgpl_2.11 + spark-ganglia-lgpl_2.12 jar Spark Ganglia Integration diff --git a/graphx/pom.xml b/graphx/pom.xml index 22bc148e068a5..444568a03d6c7 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-graphx_2.11 + spark-graphx_2.12 graphx diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 3182ab15db5f5..2e5b04622cf1c 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-hadoop-cloud_2.11 + spark-hadoop-cloud_2.12 jar Spark Project Cloud Integration through Hadoop Libraries diff --git a/launcher/pom.xml b/launcher/pom.xml index b1b6126ea5934..e75e8345cd51d 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-launcher_2.11 + spark-launcher_2.12 jar Spark Project Launcher http://spark.apache.org/ diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index ec5f9b0e92c8f..2eab868ac0dc8 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-mllib-local_2.11 + spark-mllib-local_2.12 mllib-local diff --git a/mllib/pom.xml b/mllib/pom.xml index 17ddb87c4d86a..0b17345064a71 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-mllib_2.11 + spark-mllib_2.12 mllib diff --git a/pom.xml b/pom.xml index ee1fd472a3ea7..59e3d0fa772b4 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 18 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT pom Spark Project Parent POM @@ -154,8 +154,8 @@ 3.4.1 3.2.2 - 2.11.12 - 2.11 + 2.12.7 + 2.12 1.9.13 2.9.6 1.1.7.1 @@ -1998,6 +1998,7 @@ --> org.jboss.netty org.codehaus.groovy + *:*_2.11 *:*_2.10 true @@ -2705,14 +2706,14 @@ - scala-2.11 + scala-2.12 - scala-2.12 + scala-2.11 - 2.12.7 - 2.12 + 2.11.12 + 2.11 @@ -2728,8 +2729,9 @@ - - *:*_2.11 + + org.jboss.netty + org.codehaus.groovy *:*_2.10 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 79e6745977e5b..10c02103aeddb 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -90,7 +90,7 @@ object MimaBuild { val organization = "org.apache.spark" val previousSparkVersion = "2.4.0" val project = projectRef.project - val fullId = "spark-" + project + "_2.11" + val fullId = "spark-" + project + "_2.12" mimaDefaultSettings ++ Seq(mimaPreviousArtifacts := Set(organization % fullId % previousSparkVersion), mimaBinaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5e034f9fe2a95..08e22fab65165 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -95,15 +95,15 @@ object SparkBuild extends PomBuild { } Option(System.getProperty("scala.version")) - .filter(_.startsWith("2.12")) + .filter(_.startsWith("2.11")) .foreach { versionString => - System.setProperty("scala-2.12", "true") + System.setProperty("scala-2.11", "true") } - if (System.getProperty("scala-2.12") == "") { + if (System.getProperty("scala-2.11") == "") { // To activate scala-2.10 profile, replace empty property value to non-empty value // in the 
same way as Maven which handles -Dname as -Dname=true before executes build process. // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.12", "true") + System.setProperty("scala-2.11", "true") } profiles } @@ -849,10 +849,10 @@ object TestSettings { import BuildCommons._ private val scalaBinaryVersion = - if (System.getProperty("scala-2.12") == "true") { - "2.12" - } else { + if (System.getProperty("scala-2.11") == "true") { "2.11" + } else { + "2.12" } lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those diff --git a/python/run-tests.py b/python/run-tests.py index 44305741afe3e..9fd1c9b94ac6f 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -59,9 +59,7 @@ def print_red(text): LOGGER = logging.getLogger() # Find out where the assembly jars are located. -# Later, add back 2.12 to this list: -# for scala in ["2.11", "2.12"]: -for scala in ["2.11"]: +for scala in ["2.11", "2.12"]: build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) if os.path.isdir(build_dir): SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") diff --git a/repl/pom.xml b/repl/pom.xml index fa015b69d45d4..c7de67e41ca94 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-repl_2.11 + spark-repl_2.12 jar Spark Project REPL http://spark.apache.org/ diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index b89ea383bf872..8d594ee8f1478 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../../pom.xml - spark-kubernetes_2.11 + spark-kubernetes_2.12 jar Spark Project Kubernetes diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 301b6fe8eee56..17af0e03f2bbb 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../../pom.xml - spark-kubernetes-integration-tests_2.11 + spark-kubernetes-integration-tests_2.12 1.3.0 1.4.0 diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 9585bdfafdcf4..7b3aad4d6ce35 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-mesos_2.11 + spark-mesos_2.12 jar Spark Project Mesos diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index e55b814be8465..d18df9955bb1f 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-yarn_2.11 + spark-yarn_2.12 jar Spark Project YARN diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 16ecebf159c1f..20cc5d03fbe52 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-catalyst_2.11 + spark-catalyst_2.12 jar Spark Project Catalyst http://spark.apache.org/ diff 
--git a/sql/core/pom.xml b/sql/core/pom.xml index 95e98c5444721..ac5f1fc923e7d 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-sql_2.11 + spark-sql_2.12 jar Spark Project SQL http://spark.apache.org/ diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 55e051c3ed1be..4a4629fae2706 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-hive-thriftserver_2.11 + spark-hive-thriftserver_2.12 jar Spark Project Hive Thrift Server http://spark.apache.org/ diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ef22e2abfb53e..9994689936033 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-hive_2.11 + spark-hive_2.12 jar Spark Project Hive http://spark.apache.org/ diff --git a/streaming/pom.xml b/streaming/pom.xml index f9a5029a8e818..1d1ea469f7d18 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-streaming_2.11 + spark-streaming_2.12 streaming diff --git a/tools/pom.xml b/tools/pom.xml index 247f5a6df4b08..6286fad403c83 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-tools_2.11 + spark-tools_2.12 tools From f6255d7b7cc4cc5d1f4fe0e5e493a1efee22f38f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 15 Nov 2018 08:33:06 +0800 Subject: [PATCH 045/145] [MINOR][SQL] Add disable bucketedRead workaround when throw RuntimeException ## What changes were proposed in this pull request? It will throw `RuntimeException` when read from bucketed table(about 1.7G per bucket file): ![image](https://user-images.githubusercontent.com/5399861/48346889-8041ce00-e6b7-11e8-83b0-ead83fb15821.png) Default(enable bucket read): ![image](https://user-images.githubusercontent.com/5399861/48347084-2c83b480-e6b8-11e8-913a-9cafc043e9e4.png) Disable bucket read: ![image](https://user-images.githubusercontent.com/5399861/48347099-3a393a00-e6b8-11e8-94af-cb814e1ba277.png) The reason is that each bucket file is too big. a workaround is disable bucket read. This PR add this workaround to Spark. ## How was this patch tested? manual tests Closes #23014 from wangyum/anotherWorkaround. Authored-by: Yuming Wang Signed-off-by: hyukjinkwon --- .../spark/sql/execution/vectorized/WritableColumnVector.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index b0e119d658cb4..4f5e72c1326ac 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -101,10 +101,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + "). 
As a workaround, you can reduce the vectorized reader batch size, or disable the " + - "vectorized reader. For parquet file format, refer to " + + "vectorized reader, or disable " + SQLConf.BUCKETING_ENABLED().key() + " if you read " + + "from bucket table. For Parquet file format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + - ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for ORC file format, " + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; From 03306a6df39c9fd6cb581401c13c4dfc6bbd632e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Nov 2018 12:30:52 +0800 Subject: [PATCH 046/145] [SPARK-26036][PYTHON] Break large tests.py files into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR continues to break down a big large file into smaller files. See https://github.com/apache/spark/pull/23021. It targets to follow https://github.com/numpy/numpy/tree/master/numpy. Basically this PR proposes to break down `pyspark/tests.py` into ...: ``` pyspark ... ├── testing ... │   └── utils.py ├── tests │   ├── __init__.py │   ├── test_appsubmit.py │   ├── test_broadcast.py │   ├── test_conf.py │   ├── test_context.py │   ├── test_daemon.py │   ├── test_join.py │   ├── test_profiler.py │   ├── test_rdd.py │   ├── test_readwrite.py │   ├── test_serializers.py │   ├── test_shuffle.py │   ├── test_taskcontext.py │   ├── test_util.py │   └── test_worker.py ... ``` ## How was this patch tested? Existing tests should cover. `cd python` and .`/run-tests-with-coverage`. Manually checked they are actually being ran. Each test (not officially) can be ran via: ```bash SPARK_TESTING=1 ./bin/pyspark pyspark.tests.test_context ``` Note that if you're using Mac and Python 3, you might have to `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`. Closes #23033 from HyukjinKwon/SPARK-26036. 
Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 19 +- python/pyspark/ml/tests.py | 2 +- python/pyspark/sql/tests/test_appsubmit.py | 7 +- python/pyspark/sql/tests/test_arrow.py | 7 +- python/pyspark/sql/tests/test_catalog.py | 5 +- python/pyspark/sql/tests/test_column.py | 5 +- python/pyspark/sql/tests/test_conf.py | 5 +- python/pyspark/sql/tests/test_context.py | 7 +- python/pyspark/sql/tests/test_dataframe.py | 7 +- python/pyspark/sql/tests/test_datasources.py | 5 +- python/pyspark/sql/tests/test_functions.py | 5 +- python/pyspark/sql/tests/test_group.py | 5 +- python/pyspark/sql/tests/test_pandas_udf.py | 7 +- .../sql/tests/test_pandas_udf_grouped_agg.py | 7 +- .../sql/tests/test_pandas_udf_grouped_map.py | 7 +- .../sql/tests/test_pandas_udf_scalar.py | 7 +- .../sql/tests/test_pandas_udf_window.py | 7 +- python/pyspark/sql/tests/test_readwriter.py | 5 +- python/pyspark/sql/tests/test_serde.py | 5 +- python/pyspark/sql/tests/test_session.py | 7 +- python/pyspark/sql/tests/test_streaming.py | 5 +- python/pyspark/sql/tests/test_types.py | 5 +- python/pyspark/sql/tests/test_udf.py | 7 +- python/pyspark/sql/tests/test_utils.py | 5 +- python/pyspark/test_serializers.py | 90 - python/pyspark/testing/sqlutils.py | 2 +- python/pyspark/testing/utils.py | 102 + python/pyspark/tests.py | 2502 ----------------- python/pyspark/tests/__init__.py | 16 + python/pyspark/tests/test_appsubmit.py | 248 ++ python/pyspark/{ => tests}/test_broadcast.py | 24 +- python/pyspark/tests/test_conf.py | 43 + python/pyspark/tests/test_context.py | 258 ++ python/pyspark/tests/test_daemon.py | 80 + python/pyspark/tests/test_join.py | 69 + python/pyspark/tests/test_profiler.py | 112 + python/pyspark/tests/test_rdd.py | 739 +++++ python/pyspark/tests/test_readwrite.py | 499 ++++ python/pyspark/tests/test_serializers.py | 237 ++ python/pyspark/tests/test_shuffle.py | 181 ++ python/pyspark/tests/test_taskcontext.py | 161 ++ python/pyspark/tests/test_util.py | 86 + python/pyspark/tests/test_worker.py | 157 ++ 43 files changed, 3093 insertions(+), 2666 deletions(-) delete mode 100644 python/pyspark/test_serializers.py create mode 100644 python/pyspark/testing/utils.py delete mode 100644 python/pyspark/tests.py create mode 100644 python/pyspark/tests/__init__.py create mode 100644 python/pyspark/tests/test_appsubmit.py rename python/pyspark/{ => tests}/test_broadcast.py (91%) create mode 100644 python/pyspark/tests/test_conf.py create mode 100644 python/pyspark/tests/test_context.py create mode 100644 python/pyspark/tests/test_daemon.py create mode 100644 python/pyspark/tests/test_join.py create mode 100644 python/pyspark/tests/test_profiler.py create mode 100644 python/pyspark/tests/test_rdd.py create mode 100644 python/pyspark/tests/test_readwrite.py create mode 100644 python/pyspark/tests/test_serializers.py create mode 100644 python/pyspark/tests/test_shuffle.py create mode 100644 python/pyspark/tests/test_taskcontext.py create mode 100644 python/pyspark/tests/test_util.py create mode 100644 python/pyspark/tests/test_worker.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 9dbe4e4f20e03..d5fcc060616f2 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -310,6 +310,7 @@ def __hash__(self): "python/(?!pyspark/(ml|mllib|sql|streaming))" ], python_test_goals=[ + # doctests "pyspark.rdd", "pyspark.context", "pyspark.conf", @@ -318,10 +319,22 @@ def __hash__(self): "pyspark.serializers", "pyspark.profiler", 
"pyspark.shuffle", - "pyspark.tests", - "pyspark.test_broadcast", - "pyspark.test_serializers", "pyspark.util", + # unittests + "pyspark.tests.test_appsubmit", + "pyspark.tests.test_broadcast", + "pyspark.tests.test_conf", + "pyspark.tests.test_context", + "pyspark.tests.test_daemon", + "pyspark.tests.test_join", + "pyspark.tests.test_profiler", + "pyspark.tests.test_rdd", + "pyspark.tests.test_readwrite", + "pyspark.tests.test_serializers", + "pyspark.tests.test_shuffle", + "pyspark.tests.test_taskcontext", + "pyspark.tests.test_util", + "pyspark.tests.test_worker", ] ) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 821e037af0271..2b4b7315d98c0 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -72,7 +72,7 @@ from pyspark.sql.functions import rand from pyspark.sql.types import DoubleType, IntegerType from pyspark.storagelevel import * -from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase +from pyspark.testing.utils import QuietTest, ReusedPySparkTestCase as PySparkTestCase ser = PickleSerializer() diff --git a/python/pyspark/sql/tests/test_appsubmit.py b/python/pyspark/sql/tests/test_appsubmit.py index 3c71151e396b9..43abcde7785d8 100644 --- a/python/pyspark/sql/tests/test_appsubmit.py +++ b/python/pyspark/sql/tests/test_appsubmit.py @@ -22,7 +22,7 @@ import py4j from pyspark import SparkContext -from pyspark.tests import SparkSubmitTests +from pyspark.tests.test_appsubmit import SparkSubmitTests class HiveSparkSubmitTests(SparkSubmitTests): @@ -91,6 +91,7 @@ def test_hivecontext(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 44f703569703a..6e75e82d58009 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -26,7 +26,7 @@ from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest from pyspark.util import _exception_message @@ -394,6 +394,7 @@ def conf(cls): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index 23d25770d4b01..873405a2c6aa3 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -194,6 +194,7 @@ def test_list_columns(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index faadde9527f6f..01d4f7e223a41 100644 --- a/python/pyspark/sql/tests/test_column.py 
+++ b/python/pyspark/sql/tests/test_column.py @@ -152,6 +152,7 @@ def test_bitwise_operations(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py index f5d68a8f48851..53ac4a66f4645 100644 --- a/python/pyspark/sql/tests/test_conf.py +++ b/python/pyspark/sql/tests/test_conf.py @@ -50,6 +50,7 @@ def test_conf(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index d9d408a0b9663..918f4ad2d62f4 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -25,7 +25,7 @@ from pyspark import HiveContext, Row from pyspark.sql.types import * from pyspark.sql.window import Window -from pyspark.tests import ReusedPySparkTestCase +from pyspark.testing.utils import ReusedPySparkTestCase class HiveContextSQLTests(ReusedPySparkTestCase): @@ -258,6 +258,7 @@ def range_frame_match(): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index eba00b5687d96..908d400e00092 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -25,7 +25,7 @@ from pyspark.sql.utils import AnalysisException, IllegalArgumentException from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest class DataFrameTests(ReusedSQLTestCase): @@ -732,6 +732,7 @@ def test_query_execution_listener_on_collect_with_arrow(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index b82737855a760..5579620bc2be1 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -165,6 +165,7 @@ def test_ignore_column_of_all_nulls(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py index f0b59e86af178..fe6660272e323 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -273,6 +273,7 @@ def test_sort_with_nulls_order(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py index 076899f377598..6de1b8ea0b3ce 100644 --- a/python/pyspark/sql/tests/test_group.py +++ b/python/pyspark/sql/tests/test_group.py @@ -40,6 +40,7 @@ def test_aggregator(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py index 54a34a7dc5b94..c4b5478a7e893 100644 --- a/python/pyspark/sql/tests/test_pandas_udf.py +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -21,7 +21,7 @@ from pyspark.sql.utils import ParseException from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -211,6 +211,7 @@ def foofoo(x, y): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index bca47cc3a69bf..5383704434c85 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -21,7 +21,7 @@ from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -498,6 +498,7 @@ def test_register_vectorized_udf_basic(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index 4d443887c0ed2..bfecc071386e9 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -22,7 +22,7 @@ from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import 
QuietTest @unittest.skipIf( @@ -525,6 +525,7 @@ def test_mixed_scalar_udfs_followed_by_grouby_apply(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 394ee978dcaed..2f585a3725988 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -28,7 +28,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled,\ test_not_compiled_message, have_pandas, have_pyarrow, pandas_requirement_message, \ pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -802,6 +802,7 @@ def test_datasource_with_udf(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py index 26e7993f1d9d9..f0e6d2696df62 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -21,7 +21,7 @@ from pyspark.sql.window import Window from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -257,6 +257,7 @@ def test_invalid_args(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 064d308b552c1..2f8712d7631f5 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -148,6 +148,7 @@ def count_bucketed_cols(names, table="pyspark_bucket"): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index 5ea0636dcbb6f..8707f46b6a25a 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -133,6 +133,7 @@ def test_BinaryType_serialization(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_session.py 
b/python/pyspark/sql/tests/test_session.py index b81104796fda8..c6b9e0b2ca554 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -21,7 +21,7 @@ from pyspark import SparkConf, SparkContext from pyspark.sql import SparkSession, SQLContext, Row from pyspark.testing.sqlutils import ReusedSQLTestCase -from pyspark.tests import PySparkTestCase +from pyspark.testing.utils import PySparkTestCase class SparkSessionTests(ReusedSQLTestCase): @@ -315,6 +315,7 @@ def test_use_custom_class_for_extensions(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py index cc0cab4881dc8..4b71759f74a55 100644 --- a/python/pyspark/sql/tests/test_streaming.py +++ b/python/pyspark/sql/tests/test_streaming.py @@ -561,6 +561,7 @@ def collectBatch(df, id): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 3b32c58a86639..fb673f2a385ef 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -939,6 +939,7 @@ def __init__(self, **kwargs): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 630b21517712f..d2dfb52f54475 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -27,7 +27,7 @@ from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest class UDFTests(ReusedSQLTestCase): @@ -649,6 +649,7 @@ def test_udf_init_shouldnt_initialize_context(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 63a8614d2effd..5bb921da5c2f3 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -49,6 +49,7 @@ def test_capture_illegalargument_exception(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff 
--git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py deleted file mode 100644 index 5b43729f9ebb1..0000000000000 --- a/python/pyspark/test_serializers.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import io -import math -import struct -import sys -import unittest - -try: - import xmlrunner -except ImportError: - xmlrunner = None - -from pyspark import serializers - - -def read_int(b): - return struct.unpack("!i", b)[0] - - -def write_int(i): - return struct.pack("!i", i) - - -class SerializersTest(unittest.TestCase): - - def test_chunked_stream(self): - original_bytes = bytearray(range(100)) - for data_length in [1, 10, 100]: - for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: - dest = ByteArrayOutput() - stream_out = serializers.ChunkedStream(dest, buffer_length) - stream_out.write(original_bytes[:data_length]) - stream_out.close() - num_chunks = int(math.ceil(float(data_length) / buffer_length)) - # length for each chunk, and a final -1 at the very end - exp_size = (num_chunks + 1) * 4 + data_length - self.assertEqual(len(dest.buffer), exp_size) - dest_pos = 0 - data_pos = 0 - for chunk_idx in range(num_chunks): - chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) - if chunk_idx == num_chunks - 1: - exp_length = data_length % buffer_length - if exp_length == 0: - exp_length = buffer_length - else: - exp_length = buffer_length - self.assertEqual(chunk_length, exp_length) - dest_pos += 4 - dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] - orig_chunk = original_bytes[data_pos:data_pos + chunk_length] - self.assertEqual(dest_chunk, orig_chunk) - dest_pos += chunk_length - data_pos += chunk_length - # ends with a -1 - self.assertEqual(dest.buffer[-4:], write_int(-1)) - - -class ByteArrayOutput(object): - def __init__(self): - self.buffer = bytearray() - - def write(self, b): - self.buffer += b - - def close(self): - pass - -if __name__ == '__main__': - from pyspark.test_serializers import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 3951776554847..afc40ccf4139d 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -23,7 +23,7 @@ from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row -from pyspark.tests import ReusedPySparkTestCase +from pyspark.testing.utils import ReusedPySparkTestCase from pyspark.util import _exception_message diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py new file mode 100644 index 0000000000000..7df0acae026f3 --- /dev/null +++ 
b/python/pyspark/testing/utils.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import struct +import sys +import unittest + +from pyspark import SparkContext, SparkConf + + +have_scipy = False +have_numpy = False +try: + import scipy.sparse + have_scipy = True +except: + # No SciPy, but that's okay, we'll skip those tests + pass +try: + import numpy as np + have_numpy = True +except: + # No NumPy, but that's okay, we'll skip those tests + pass + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +def read_int(b): + return struct.unpack("!i", b)[0] + + +def write_int(i): + return struct.pack("!i", i) + + +class QuietTest(object): + def __init__(self, sc): + self.log4j = sc._jvm.org.apache.log4j + + def __enter__(self): + self.old_level = self.log4j.LogManager.getRootLogger().getLevel() + self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.log4j.LogManager.getRootLogger().setLevel(self.old_level) + + +class PySparkTestCase(unittest.TestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + self.sc = SparkContext('local[4]', class_name) + + def tearDown(self): + self.sc.stop() + sys.path = self._old_sys_path + + +class ReusedPySparkTestCase(unittest.TestCase): + + @classmethod + def conf(cls): + """ + Override this in subclasses to supply a more specific conf + """ + return SparkConf() + + @classmethod + def setUpClass(cls): + cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf()) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + +class ByteArrayOutput(object): + def __init__(self): + self.buffer = bytearray() + + def write(self, b): + self.buffer += b + + def close(self): + pass diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py deleted file mode 100644 index 131c51e108cad..0000000000000 --- a/python/pyspark/tests.py +++ /dev/null @@ -1,2502 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -""" -Unit tests for PySpark; additional tests are implemented as doctests in -individual modules. -""" - -from array import array -from glob import glob -import os -import re -import shutil -import subprocess -import sys -import tempfile -import time -import zipfile -import random -import threading -import hashlib - -from py4j.protocol import Py4JJavaError -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - if sys.version_info[0] >= 3: - xrange = range - basestring = str - -if sys.version >= "3": - from io import StringIO -else: - from StringIO import StringIO - - -from pyspark import keyword_only -from pyspark.conf import SparkConf -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.files import SparkFiles -from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ - PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ - FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter -from pyspark import shuffle -from pyspark.profiler import BasicProfiler -from pyspark.taskcontext import BarrierTaskContext, TaskContext - -_have_scipy = False -_have_numpy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass -try: - import numpy as np - _have_numpy = True -except: - # No NumPy, but that's okay, we'll skip those tests - pass - - -SPARK_HOME = os.environ["SPARK_HOME"] - - -class MergerTests(unittest.TestCase): - - def setUp(self): - self.N = 1 << 12 - self.l = [i for i in xrange(self.N)] - self.data = list(zip(self.l, self.l)) - self.agg = Aggregator(lambda x: [x], - lambda x, y: x.append(y) or x, - lambda x, y: x.extend(y) or x) - - def test_small_dataset(self): - m = ExternalMerger(self.agg, 1000) - m.mergeValues(self.data) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - def test_medium_dataset(self): - m = ExternalMerger(self.agg, 20) - m.mergeValues(self.data) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N)) * 3) - - def test_huge_dataset(self): - m = ExternalMerger(self.agg, 5, partitions=3) - m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m.items()), - self.N * 10) - m._cleanup() - - def test_group_by_key(self): - - def gen_data(N, step): - for i in range(1, N + 1, step): - for j in range(i): - yield (i, [j]) - - def gen_gs(N, step=1): - return shuffle.GroupByKey(gen_data(N, step)) - - self.assertEqual(1, len(list(gen_gs(1)))) - self.assertEqual(2, len(list(gen_gs(2)))) - self.assertEqual(100, 
len(list(gen_gs(100)))) - self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) - self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) - - for k, vs in gen_gs(50002, 10000): - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - ser = PickleSerializer() - l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) - for k, vs in l: - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - def test_stopiteration_is_raised(self): - - def stopit(*args, **kwargs): - raise StopIteration() - - def legit_create_combiner(x): - return [x] - - def legit_merge_value(x, y): - return x.append(y) or x - - def legit_merge_combiners(x, y): - return x.extend(y) or x - - data = [(x % 2, x) for x in range(100)] - - # wrong create combiner - m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge value - m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge combiners - m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) - - -class SorterTests(unittest.TestCase): - def test_in_memory_sort(self): - l = list(range(1024)) - random.shuffle(l) - sorter = ExternalSorter(1024) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - - def test_external_sort(self): - class CustomizedSorter(ExternalSorter): - def _next_limit(self): - return self.memory_limit - l = list(range(1024)) - random.shuffle(l) - sorter = CustomizedSorter(1) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertGreater(shuffle.DiskBytesSpilled, 0) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - - def test_external_sort_in_rdd(self): - conf = SparkConf().set("spark.python.worker.memory", "1m") - sc = SparkContext(conf=conf) - l = list(range(10240)) - random.shuffle(l) - rdd = sc.parallelize(l, 4) - self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class SerializationTestCase(unittest.TestCase): - - def test_namedtuple(self): - from collections import namedtuple - from pickle import dumps, loads - P = namedtuple("P", "x y") - p1 = P(1, 3) - p2 = loads(dumps(p1, 2)) - self.assertEqual(p1, p2) - - from pyspark.cloudpickle import dumps - P2 = loads(dumps(P)) - p3 = P2(1, 3) - self.assertEqual(p1, p3) - - def test_itemgetter(self): - from operator import itemgetter - ser = CloudPickleSerializer() - d = range(10) - getter = 
itemgetter(1) - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - getter = itemgetter(0, 3) - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - def test_function_module_name(self): - ser = CloudPickleSerializer() - func = lambda x: x - func2 = ser.loads(ser.dumps(func)) - self.assertEqual(func.__module__, func2.__module__) - - def test_attrgetter(self): - from operator import attrgetter - ser = CloudPickleSerializer() - - class C(object): - def __getattr__(self, item): - return item - d = C() - getter = attrgetter("a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("a", "b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - d.e = C() - getter = attrgetter("e.a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("e.a", "e.b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - # Regression test for SPARK-3415 - def test_pickling_file_handles(self): - # to be corrected with SPARK-11160 - if not xmlrunner: - ser = CloudPickleSerializer() - out1 = sys.stderr - out2 = ser.loads(ser.dumps(out1)) - self.assertEqual(out1, out2) - - def test_func_globals(self): - - class Unpicklable(object): - def __reduce__(self): - raise Exception("not picklable") - - global exit - exit = Unpicklable() - - ser = CloudPickleSerializer() - self.assertRaises(Exception, lambda: ser.dumps(exit)) - - def foo(): - sys.exit(0) - - self.assertTrue("exit" in foo.__code__.co_names) - ser.dumps(foo) - - def test_compressed_serializer(self): - ser = CompressedSerializer(PickleSerializer()) - try: - from StringIO import StringIO - except ImportError: - from io import BytesIO as StringIO - io = StringIO() - ser.dump_stream(["abc", u"123", range(5)], io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) - ser.dump_stream(range(1000), io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) - io.close() - - def test_hash_serializer(self): - hash(NoOpSerializer()) - hash(UTF8Deserializer()) - hash(PickleSerializer()) - hash(MarshalSerializer()) - hash(AutoSerializer()) - hash(BatchedSerializer(PickleSerializer())) - hash(AutoBatchedSerializer(MarshalSerializer())) - hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CompressedSerializer(PickleSerializer())) - hash(FlattenedValuesSerializer(PickleSerializer())) - - -class QuietTest(object): - def __init__(self, sc): - self.log4j = sc._jvm.org.apache.log4j - - def __enter__(self): - self.old_level = self.log4j.LogManager.getRootLogger().getLevel() - self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) - - def __exit__(self, exc_type, exc_val, exc_tb): - self.log4j.LogManager.getRootLogger().setLevel(self.old_level) - - -class PySparkTestCase(unittest.TestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name) - - def tearDown(self): - self.sc.stop() - sys.path = self._old_sys_path - - -class ReusedPySparkTestCase(unittest.TestCase): - - @classmethod - def conf(cls): - """ - Override this in subclasses to supply a more specific conf - """ - return SparkConf() - - @classmethod - def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, 
conf=cls.conf()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - - -class CheckpointTests(ReusedPySparkTestCase): - - def setUp(self): - self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.checkpointDir.name) - self.sc.setCheckpointDir(self.checkpointDir.name) - - def tearDown(self): - shutil.rmtree(self.checkpointDir.name) - - def test_basic_checkpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - self.assertEqual("file:" + self.checkpointDir.name, - os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) - - def test_checkpoint_and_restore(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: [x]) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - flatMappedRDD.count() # forces a checkpoint to be computed - time.sleep(1) # 1 second - - self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), - flatMappedRDD._jrdd_deserializer) - self.assertEqual([1, 2, 3, 4], recovered.collect()) - - -class LocalCheckpointTests(ReusedPySparkTestCase): - - def test_basic_localcheckpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) - - flatMappedRDD.localCheckpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - - -class AddFileTests(PySparkTestCase): - - def test_add_py_file(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this job fails due to `userlibrary` not being on the Python path: - # disable logging in log4j temporarily - def func(x): - from userlibrary import UserClass - return UserClass().hello() - with QuietTest(self.sc): - self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) - - # Add the file, so the job should now succeed: - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - res = self.sc.parallelize(range(2)).map(func).first() - self.assertEqual("Hello World!", res) - - def test_add_file_locally(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - self.sc.addFile(path) - download_path = SparkFiles.get("hello.txt") - self.assertNotEqual(path, download_path) - with open(download_path) as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - - def test_add_file_recursively_locally(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello") - self.sc.addFile(path, True) - download_path = SparkFiles.get("hello") - self.assertNotEqual(path, download_path) - with open(download_path + "/hello.txt") as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - with open(download_path + 
"/sub_hello/sub_hello.txt") as test_file: - self.assertEqual("Sub Hello World!\n", test_file.readline()) - - def test_add_py_file_locally(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlibrary import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - from userlibrary import UserClass - self.assertEqual("Hello World!", UserClass().hello()) - - def test_add_egg_file_locally(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlib import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") - self.sc.addPyFile(path) - from userlib import UserClass - self.assertEqual("Hello World from inside a package!", UserClass().hello()) - - def test_overwrite_system_module(self): - self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) - - import SimpleHTTPServer - self.assertEqual("My Server", SimpleHTTPServer.__name__) - - def func(x): - import SimpleHTTPServer - return SimpleHTTPServer.__name__ - - self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) - - -class TaskContextTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - # Allow retries even though they are normally disabled in local mode - self.sc = SparkContext('local[4, 2]', class_name) - - def test_stage_id(self): - """Test the stage ids are available and incrementing as expected.""" - rdd = self.sc.parallelize(range(10)) - stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - # Test using the constructor directly rather than the get() - stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] - self.assertEqual(stage1 + 1, stage2) - self.assertEqual(stage1 + 2, stage3) - self.assertEqual(stage2 + 1, stage3) - - def test_partition_id(self): - """Test the partition id.""" - rdd1 = self.sc.parallelize(range(10), 1) - rdd2 = self.sc.parallelize(range(10), 2) - pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() - pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() - self.assertEqual(0, pids1[0]) - self.assertEqual(0, pids1[9]) - self.assertEqual(0, pids2[0]) - self.assertEqual(1, pids2[9]) - - def test_attempt_number(self): - """Verify the attempt numbers are correctly reported.""" - rdd = self.sc.parallelize(range(10)) - # Verify a simple job with no failures - attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() - map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) - - def fail_on_first(x): - """Fail on the first attempt so we get a positive attempt number""" - tc = TaskContext.get() - attempt_number = tc.attemptNumber() - partition_id = tc.partitionId() - attempt_id = tc.taskAttemptId() - if attempt_number == 0 and partition_id == 0: - raise Exception("Failing on first attempt") - else: - return [x, partition_id, attempt_number, attempt_id] - result = rdd.map(fail_on_first).collect() - # We should re-submit the first partition to it but other partitions should be attempt 0 - self.assertEqual([0, 0, 1], result[0][0:3]) - self.assertEqual([9, 3, 0], 
result[9][0:3]) - first_partition = filter(lambda x: x[1] == 0, result) - map(lambda x: self.assertEqual(1, x[2]), first_partition) - other_partitions = filter(lambda x: x[1] != 0, result) - map(lambda x: self.assertEqual(0, x[2]), other_partitions) - # The task attempt id should be different - self.assertTrue(result[0][3] != result[9][3]) - - def test_tc_on_driver(self): - """Verify that getting the TaskContext on the driver returns None.""" - tc = TaskContext.get() - self.assertTrue(tc is None) - - def test_get_local_property(self): - """Verify that local properties set on the driver are available in TaskContext.""" - key = "testkey" - value = "testvalue" - self.sc.setLocalProperty(key, value) - try: - rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] - self.assertEqual(prop1, value) - prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] - self.assertTrue(prop2 is None) - finally: - self.sc.setLocalProperty(key, None) - - def test_barrier(self): - """ - Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks - within a stage. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - def context_barrier(x): - tc = BarrierTaskContext.get() - time.sleep(random.randint(1, 10)) - tc.barrier() - return time.time() - - times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() - self.assertTrue(max(times) - min(times) < 1) - - def test_barrier_with_python_worker_reuse(self): - """ - Verify that BarrierTaskContext.barrier() with reused python worker. - """ - self.sc._conf.set("spark.python.work.reuse", "true") - rdd = self.sc.parallelize(range(4), 4) - # start a normal job first to start all worker - result = rdd.map(lambda x: x ** 2).collect() - self.assertEqual([0, 1, 4, 9], result) - # make sure `spark.python.work.reuse=true` - self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true") - - # worker will be reused in this barrier job - self.test_barrier() - - def test_barrier_infos(self): - """ - Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the - barrier stage. 
- """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() - .getTaskInfos()).collect() - self.assertTrue(len(taskInfos) == 4) - self.assertTrue(len(taskInfos[0]) == 4) - - -class RDDTests(ReusedPySparkTestCase): - - def test_range(self): - self.assertEqual(self.sc.range(1, 1).count(), 0) - self.assertEqual(self.sc.range(1, 0, -1).count(), 1) - self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) - - def test_id(self): - rdd = self.sc.parallelize(range(10)) - id = rdd.id() - self.assertEqual(id, rdd.id()) - rdd2 = rdd.map(str).filter(bool) - id2 = rdd2.id() - self.assertEqual(id + 1, id2) - self.assertEqual(id2, rdd2.id()) - - def test_empty_rdd(self): - rdd = self.sc.emptyRDD() - self.assertTrue(rdd.isEmpty()) - - def test_sum(self): - self.assertEqual(0, self.sc.emptyRDD().sum()) - self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) - - def test_to_localiterator(self): - from time import sleep - rdd = self.sc.parallelize([1, 2, 3]) - it = rdd.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it)) - - rdd2 = rdd.repartition(1000) - it2 = rdd2.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it2)) - - def test_save_as_textfile_with_unicode(self): - # Regression test for SPARK-970 - x = u"\u00A1Hola, mundo!" - data = self.sc.parallelize([x]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode("utf-8")) - - def test_save_as_textfile_with_utf8(self): - x = u"\u00A1Hola, mundo!" - data = self.sc.parallelize([x.encode("utf-8")]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode('utf8')) - - def test_transforming_cartesian_result(self): - # Regression test for SPARK-1034 - rdd1 = self.sc.parallelize([1, 2]) - rdd2 = self.sc.parallelize([3, 4]) - cart = rdd1.cartesian(rdd2) - result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() - - def test_transforming_pickle_file(self): - # Regression test for SPARK-2601 - data = self.sc.parallelize([u"Hello", u"World!"]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsPickleFile(tempFile.name) - pickled_file = self.sc.pickleFile(tempFile.name) - pickled_file.map(lambda x: x).collect() - - def test_cartesian_on_textfile(self): - # Regression test for - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - a = self.sc.textFile(path) - result = a.cartesian(a).collect() - (x, y) = result[0] - self.assertEqual(u"Hello World!", x.strip()) - self.assertEqual(u"Hello World!", y.strip()) - - def test_cartesian_chaining(self): - # Tests for SPARK-16589 - rdd = self.sc.parallelize(range(10), 2) - self.assertSetEqual( - set(rdd.cartesian(rdd).cartesian(rdd).collect()), - set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.cartesian(rdd)).collect()), - set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.zip(rdd)).collect()), - set([(x, (y, y)) for x in range(10) for y in range(10)]) - ) - - def 
test_zip_chaining(self): - # Tests for SPARK-21985 - rdd = self.sc.parallelize('abc', 2) - self.assertSetEqual( - set(rdd.zip(rdd).zip(rdd).collect()), - set([((x, x), x) for x in 'abc']) - ) - self.assertSetEqual( - set(rdd.zip(rdd.zip(rdd)).collect()), - set([(x, (x, x)) for x in 'abc']) - ) - - def test_deleting_input_files(self): - # Regression test for SPARK-1025 - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - def test_sampling_default_seed(self): - # Test for SPARK-3995 (default seed setting) - data = self.sc.parallelize(xrange(1000), 1) - subset = data.takeSample(False, 10) - self.assertEqual(len(subset), 10) - - def test_aggregate_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregate and treeAggregate to build dict - # representing a counter of ints - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - # Show that single or multiple partitions work - data1 = self.sc.range(10, numSlices=1) - data2 = self.sc.range(10, numSlices=2) - - def seqOp(x, y): - x[y] += 1 - return x - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) - counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) - counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - - ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) - self.assertEqual(counts1, ground_truth) - self.assertEqual(counts2, ground_truth) - self.assertEqual(counts3, ground_truth) - self.assertEqual(counts4, ground_truth) - - def test_aggregate_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that - # contains lists of all values for each key in the original RDD - - # list(range(...)) for Python 3.x compatibility (can't use * operator - # on a range object) - # list(zip(...)) for Python 3.x compatibility (want to parallelize a - # collection, not a zip object) - tuples = list(zip(list(range(10))*2, [1]*20)) - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def seqOp(x, y): - x.append(y) - return x - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.aggregateByKey([], seqOp, comboOp).collect() - values2 = data2.aggregateByKey([], seqOp, comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - ground_truth = [(i, [1]*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_fold_mutable_zero_value(self): - # Test for SPARK-9021; uses fold to merge an RDD of dict counters into - # a single dict - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - counts1 = defaultdict(int, dict((i, 1) for i in range(10))) - counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) - counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) - counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) - all_counts = [counts1, counts2, 
counts3, counts4] - # Show that single or multiple partitions work - data1 = self.sc.parallelize(all_counts, 1) - data2 = self.sc.parallelize(all_counts, 2) - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - fold1 = data1.fold(defaultdict(int), comboOp) - fold2 = data2.fold(defaultdict(int), comboOp) - - ground_truth = defaultdict(int) - for counts in all_counts: - for key, val in counts.items(): - ground_truth[key] += val - self.assertEqual(fold1, ground_truth) - self.assertEqual(fold2, ground_truth) - - def test_fold_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains - # lists of all values for each key in the original RDD - - tuples = [(i, range(i)) for i in range(10)]*2 - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.foldByKey([], comboOp).collect() - values2 = data2.foldByKey([], comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - # list(range(...)) for Python 3.x compatibility - ground_truth = [(i, list(range(i))*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_aggregate_by_key(self): - data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) - - def seqOp(x, y): - x.add(y) - return x - - def combOp(x, y): - x |= y - return x - - sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) - self.assertEqual(3, len(sets)) - self.assertEqual(set([1]), sets[1]) - self.assertEqual(set([2]), sets[3]) - self.assertEqual(set([1, 3]), sets[5]) - - def test_itemgetter(self): - rdd = self.sc.parallelize([range(10)]) - from operator import itemgetter - self.assertEqual([1], rdd.map(itemgetter(1)).collect()) - self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) - - def test_namedtuple_in_rdd(self): - from collections import namedtuple - Person = namedtuple("Person", "id firstName lastName") - jon = Person(1, "Jon", "Doe") - jane = Person(2, "Jane", "Doe") - theDoes = self.sc.parallelize([jon, jane]) - self.assertEqual([jon, jane], theDoes.collect()) - - def test_large_broadcast(self): - N = 10000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 27MB - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - - def test_unpersist(self): - N = 1000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 3MB - bdata.unpersist() - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - bdata.destroy() - try: - self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - except Exception as e: - pass - else: - raise Exception("job should fail after destroy the broadcast") - - def test_multiple_broadcasts(self): - N = 1 << 21 - b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = list(range(1 << 15)) - random.shuffle(r) - s = str(r).encode() - checksum = hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - random.shuffle(r) - s = str(r).encode() - checksum 
= hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - def test_multithread_broadcast_pickle(self): - import threading - - b1 = self.sc.broadcast(list(range(3))) - b2 = self.sc.broadcast(list(range(3))) - - def f1(): - return b1.value - - def f2(): - return b2.value - - funcs_num_pickled = {f1: None, f2: None} - - def do_pickle(f, sc): - command = (f, None, sc.serializer, sc.serializer) - ser = CloudPickleSerializer() - ser.dumps(command) - - def process_vars(sc): - broadcast_vars = list(sc._pickled_broadcast_vars) - num_pickled = len(broadcast_vars) - sc._pickled_broadcast_vars.clear() - return num_pickled - - def run(f, sc): - do_pickle(f, sc) - funcs_num_pickled[f] = process_vars(sc) - - # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage - do_pickle(f1, self.sc) - - # run all for f2, should only add/count/clear b2 from worker thread local storage - t = threading.Thread(target=run, args=(f2, self.sc)) - t.start() - t.join() - - # count number of vars pickled in main thread, only b1 should be counted and cleared - funcs_num_pickled[f1] = process_vars(self.sc) - - self.assertEqual(funcs_num_pickled[f1], 1) - self.assertEqual(funcs_num_pickled[f2], 1) - self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) - - def test_large_closure(self): - N = 200000 - data = [float(i) for i in xrange(N)] - rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEqual(N, rdd.first()) - # regression test for SPARK-6886 - self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) - - def test_zip_with_different_serializers(self): - a = self.sc.parallelize(range(5)) - b = self.sc.parallelize(range(100, 105)) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) - b = b._reserialize(MarshalSerializer()) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - # regression test for SPARK-4841 - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - t = self.sc.textFile(path) - cnt = t.count() - self.assertEqual(cnt, t.zip(t).count()) - rdd = t.map(str) - self.assertEqual(cnt, t.zip(rdd).count()) - # regression test for bug in _reserializer() - self.assertEqual(cnt, t.zip(rdd).count()) - - def test_zip_with_different_object_sizes(self): - # regress test for SPARK-5973 - a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) - b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) - self.assertEqual(10000, a.zip(b).count()) - - def test_zip_with_different_number_of_items(self): - a = self.sc.parallelize(range(5), 2) - # different number of partitions - b = self.sc.parallelize(range(100, 106), 3) - self.assertRaises(ValueError, lambda: a.zip(b)) - with QuietTest(self.sc): - # different number of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - 
self.assertEqual(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) - - def test_count_approx_distinct(self): - rdd = self.sc.parallelize(xrange(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) - - rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) - self.assertTrue(18 < rdd.countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) - - self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) - - def test_histogram(self): - # empty - rdd = self.sc.parallelize([]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - self.assertRaises(ValueError, lambda: rdd.histogram(1)) - - # out of range - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) - - # in range with one bucket - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual([4], rdd.histogram([0, 10])[1]) - self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) - - # in range with one bucket exact match - self.assertEqual([4], rdd.histogram([1, 4])[1]) - - # out of range with two buckets - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) - - # out of range with two uneven buckets - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - - # in range with two buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two bucket and None - rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two uneven buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) - - # mixed range with two uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) - - # mixed range with four uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # mixed range with uneven buckets and NaN - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, - 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # out of range with infinite buckets - rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) - - # invalid buckets - self.assertRaises(ValueError, lambda: rdd.histogram([])) - self.assertRaises(ValueError, lambda: rdd.histogram([1])) - self.assertRaises(ValueError, lambda: rdd.histogram(0)) - self.assertRaises(TypeError, lambda: rdd.histogram({})) - - # without buckets - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 4], [4]), rdd.histogram(1)) - - # without buckets single element - rdd = self.sc.parallelize([1]) - 
self.assertEqual(([1, 1], [1]), rdd.histogram(1)) - - # without bucket no range - rdd = self.sc.parallelize([1] * 4) - self.assertEqual(([1, 1], [4]), rdd.histogram(1)) - - # without buckets basic two - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) - - # without buckets with more requested than elements - rdd = self.sc.parallelize([1, 2]) - buckets = [1 + 0.2 * i for i in range(6)] - hist = [1, 0, 0, 0, 1] - self.assertEqual((buckets, hist), rdd.histogram(5)) - - # invalid RDDs - rdd = self.sc.parallelize([1, float('inf')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - rdd = self.sc.parallelize([float('nan')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - - # string - rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - def test_repartitionAndSortWithinPartitions_asc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) - - def test_repartitionAndSortWithinPartitions_desc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) - self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) - - def test_repartition_no_skewed(self): - num_partitions = 20 - a = self.sc.parallelize(range(int(1000)), 2) - l = a.repartition(num_partitions).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - l = a.coalesce(num_partitions, True).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - - def test_repartition_on_textfile(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - rdd = self.sc.textFile(path) - result = rdd.repartition(1).collect() - self.assertEqual(u"Hello World!", result[0]) - - def test_distinct(self): - rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEqual(rdd.getNumPartitions(), 10) - self.assertEqual(rdd.distinct().count(), 3) - result = rdd.distinct(5) - self.assertEqual(result.getNumPartitions(), 5) - self.assertEqual(result.count(), 3) - - def test_external_group_by_key(self): - self.sc._conf.set("spark.python.worker.memory", "1m") - N = 200001 - kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) - gkv = kv.groupByKey().cache() - self.assertEqual(3, gkv.count()) - filtered = gkv.filter(lambda kv: kv[0] == 1) - self.assertEqual(1, filtered.count()) - self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) - self.assertEqual([(N // 3, N // 3)], - filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) - result = filtered.collect()[0][1] - self.assertEqual(N // 3, len(result)) - self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) - - def test_sort_on_empty_rdd(self): - self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) - - def test_sample(self): - rdd = self.sc.parallelize(range(0, 100), 4) - wo = rdd.sample(False, 0.1, 
2).collect() - wo_dup = rdd.sample(False, 0.1, 2).collect() - self.assertSetEqual(set(wo), set(wo_dup)) - wr = rdd.sample(True, 0.2, 5).collect() - wr_dup = rdd.sample(True, 0.2, 5).collect() - self.assertSetEqual(set(wr), set(wr_dup)) - wo_s10 = rdd.sample(False, 0.3, 10).collect() - wo_s20 = rdd.sample(False, 0.3, 20).collect() - self.assertNotEqual(set(wo_s10), set(wo_s20)) - wr_s11 = rdd.sample(True, 0.4, 11).collect() - wr_s21 = rdd.sample(True, 0.4, 21).collect() - self.assertNotEqual(set(wr_s11), set(wr_s21)) - - def test_null_in_rdd(self): - jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) - rdd = RDD(jrdd, self.sc, UTF8Deserializer()) - self.assertEqual([u"a", None, u"b"], rdd.collect()) - rdd = RDD(jrdd, self.sc, NoOpSerializer()) - self.assertEqual([b"a", None, b"b"], rdd.collect()) - - def test_multiple_python_java_RDD_conversions(self): - # Regression test for SPARK-5361 - data = [ - (u'1', {u'director': u'David Lean'}), - (u'2', {u'director': u'Andrew Dominik'}) - ] - data_rdd = self.sc.parallelize(data) - data_java_rdd = data_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - # conversion between python and java RDD threw exceptions - data_java_rdd = converted_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - def test_narrow_dependency_in_join(self): - rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) - parted = rdd.partitionBy(2) - self.assertEqual(2, parted.union(parted).getNumPartitions()) - self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) - self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - - tracker = self.sc.statusTracker() - - self.sc.setJobGroup("test1", "test", True) - d = sorted(parted.join(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test1")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test2", "test", True) - d = sorted(parted.join(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test2")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test3", "test", True) - d = sorted(parted.cogroup(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test3")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test4", "test", True) - d = sorted(parted.cogroup(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test4")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - # Regression test for SPARK-6294 - def test_take_on_jrdd(self): - rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) - rdd._jrdd.first() - - def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): - # Regression test for SPARK-5969 - seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence - rdd = self.sc.parallelize(seq) - for ascending in [True, False]: - sort = rdd.sortByKey(ascending=ascending, numPartitions=5) - 
self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) - sizes = sort.glom().map(len).collect() - for size in sizes: - self.assertGreater(size, 0) - - def test_pipe_functions(self): - data = ['1', '2', '3'] - rdd = self.sc.parallelize(data) - with QuietTest(self.sc): - self.assertEqual([], rdd.pipe('cc').collect()) - self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) - result = rdd.pipe('cat').collect() - result.sort() - for x, y in zip(data, result): - self.assertEqual(x, y) - self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) - self.assertEqual([], rdd.pipe('grep 4').collect()) - - def test_pipe_unicode(self): - # Regression test for SPARK-20947 - data = [u'\u6d4b\u8bd5', '1'] - rdd = self.sc.parallelize(data) - result = rdd.pipe('cat').collect() - self.assertEqual(data, result) - - def test_stopiteration_in_user_code(self): - - def stopit(*x): - raise StopIteration() - - seq_rdd = self.sc.parallelize(range(10)) - keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - msg = "Caught StopIteration thrown from user's code; failing the task" - - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, - seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - - # these methods call the user function both in the driver and in the executor - # the exception raised is different according to where the StopIteration happens - # RuntimeError is raised if in the driver - # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, lambda *x: 1, stopit) - - -class ProfilerTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, conf=conf) - - def test_profiler(self): - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - id, profiler, _ = profilers[0] - stats = profiler.stats() - self.assertTrue(stats is not None) - width, stat_list = stats.get_print_list([]) - func_names = [func_name for fname, n, func_name in stat_list] - self.assertTrue("heavy_foo" in func_names) - - old_stdout = sys.stdout - sys.stdout = io = StringIO() - self.sc.show_profiles() - self.assertTrue("heavy_foo" in io.getvalue()) - sys.stdout = old_stdout - - d = tempfile.gettempdir() - self.sc.dump_profiles(d) - self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) - - def test_custom_profiler(self): - class TestCustomProfiler(BasicProfiler): - def show(self, id): - self.result = "Custom formatting" - - self.sc.profiler_collector.profiler_cls = TestCustomProfiler - - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - _, profiler, _ = 
profilers[0] - self.assertTrue(isinstance(profiler, TestCustomProfiler)) - - self.sc.show_profiles() - self.assertEqual("Custom formatting", profiler.result) - - def do_computation(self): - def heavy_foo(x): - for i in range(1 << 18): - x = 1 - - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - - -class ProfilerTests2(unittest.TestCase): - def test_profiler_disabled(self): - sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) - try: - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.show_profiles()) - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.dump_profiles("/tmp/abc")) - finally: - sc.stop() - - -class InputFormatTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", - "org.apache.hadoop.io.DoubleWritable", - "org.apache.hadoop.io.Text").collect()) - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.assertEqual(doubles, ed) - - bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) - ebs = [(1, bytearray('aa', 'utf-8')), - (1, bytearray('aa', 'utf-8')), - (2, bytearray('aa', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (3, bytearray('cc', 'utf-8'))] - self.assertEqual(bytes, ebs) - - text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", - "org.apache.hadoop.io.Text", - "org.apache.hadoop.io.Text").collect()) - et = [(u'1', u'aa'), - (u'1', u'aa'), - (u'2', u'aa'), - (u'2', u'bb'), - (u'2', u'bb'), - (u'3', u'cc')] - self.assertEqual(text, et) - - bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.assertEqual(bools, eb) - - nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.assertEqual(nulls, en) - - maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - for v in maps: - self.assertTrue(v in em) - - # arrays get pickled to tuples by default - tuples = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - 
"org.apache.spark.api.python.DoubleArrayWritable").collect()) - et = [(1, ()), - (2, (3.0, 4.0, 5.0)), - (3, (4.0, 5.0, 6.0))] - self.assertEqual(tuples, et) - - # with custom converters, primitive arrays can stay as arrays - arrays = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - ea = [(1, array('d')), - (2, array('d', [3.0, 4.0, 5.0])), - (3, array('d', [4.0, 5.0, 6.0]))] - self.assertEqual(arrays, ea) - - clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable").collect()) - cname = u'org.apache.spark.api.python.TestWritable' - ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), - (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), - (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), - (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), - (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] - self.assertEqual(clazz, ec) - - unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - ).collect()) - self.assertEqual(unbatched_clazz, ec) - - def test_oldhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=oldconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=newconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newolderror(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def 
test_bad_inputs(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.sequenceFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.NotValidWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - maps = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - keyConverter="org.apache.spark.api.python.TestInputKeyConverter", - valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) - em = [(u'\x01', []), - (u'\x01', [3.0]), - (u'\x02', [1.0]), - (u'\x02', [1.0]), - (u'\x03', [2.0])] - self.assertEqual(maps, em) - - def test_binary_files(self): - path = os.path.join(self.tempdir.name, "binaryfiles") - os.mkdir(path) - data = b"short binary data" - with open(os.path.join(path, "part-0000"), 'wb') as f: - f.write(data) - [(p, d)] = self.sc.binaryFiles(path).collect() - self.assertTrue(p.endswith("part-0000")) - self.assertEqual(d, data) - - def test_binary_records(self): - path = os.path.join(self.tempdir.name, "binaryrecords") - os.mkdir(path) - with open(os.path.join(path, "part-0000"), 'w') as f: - for i in range(100): - f.write('%04d' % i) - result = self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(list(range(100)), result) - - -class OutputFormatTests(ReusedPySparkTestCase): - - def setUp(self): - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - - def tearDown(self): - shutil.rmtree(self.tempdir.name, ignore_errors=True) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") - ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) - self.assertEqual(ints, ei) - - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") - doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) - self.assertEqual(doubles, ed) - - ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] - self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") - bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) - self.assertEqual(bytes, ebs) - - et = [(u'1', u'aa'), - (u'2', u'bb'), - (u'3', u'cc')] - self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") - text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) - self.assertEqual(text, et) - - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") - bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) - self.assertEqual(bools, eb) - - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - 
self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") - nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) - self.assertEqual(nulls, en) - - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() - for v in maps: - self.assertTrue(v, em) - - def test_oldhadoop(self): - basepath = self.tempdir.name - dict_data = [(1, {}), - (1, {"row1": 1.0}), - (2, {"row2": 2.0})] - self.sc.parallelize(dict_data).saveAsHadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable") - result = self.sc.hadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - for v in result: - self.assertTrue(v, dict_data) - - conf = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" - } - self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} - result = self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect() - for v in result: - self.assertTrue(v, dict_data) - - def test_newhadoop(self): - basepath = self.tempdir.name - data = [(1, ""), - (1, "a"), - (2, "bcdf")] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - self.assertEqual(result, data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=input_conf).collect()) - self.assertEqual(new_dataset, data) - - @unittest.skipIf(sys.version >= "3", "serialize of array") - def test_newhadoop_with_array(self): - basepath = self.tempdir.name - # use custom ArrayWritable types and converters to handle arrays - array_data = [(1, array('d')), - (1, array('d', [1.0, 2.0, 3.0])), - (2, array('d', [3.0, 4.0, 5.0]))] - self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - 
"org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - self.assertEqual(result, array_data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( - conf, - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", - conf=input_conf).collect()) - self.assertEqual(new_dataset, array_data) - - def test_newolderror(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/newolderror/saveAsHadoopFile/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/newolderror/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - def test_bad_inputs(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/badinputs/saveAsHadoopFile/", - "org.apache.hadoop.mapred.NotValidOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/badinputs/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - data = [(1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/converters/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", - valueConverter="org.apache.spark.api.python.TestOutputValueConverter") - converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) - expected = [(u'1', 3.0), - (u'2', 1.0), - (u'3', 2.0)] - self.assertEqual(converted, expected) - - def test_reserialization(self): - basepath = self.tempdir.name - x = range(1, 5) - y = range(1001, 1005) - data = list(zip(x, y)) - rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) - rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") - result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) - self.assertEqual(result1, data) - - rdd.saveAsHadoopFile( - basepath + "/reserialize/hadoop", - "org.apache.hadoop.mapred.SequenceFileOutputFormat") - result2 = 
sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) - self.assertEqual(result2, data) - - rdd.saveAsNewAPIHadoopFile( - basepath + "/reserialize/newhadoop", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) - self.assertEqual(result3, data) - - conf4 = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} - rdd.saveAsHadoopDataset(conf4) - result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) - self.assertEqual(result4, data) - - conf5 = {"mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" - } - rdd.saveAsNewAPIHadoopDataset(conf5) - result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) - self.assertEqual(result5, data) - - def test_malformed_RDD(self): - basepath = self.tempdir.name - # non-batch-serialized RDD[[(K, V)]] should be rejected - data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] - rdd = self.sc.parallelize(data, len(data)) - self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( - basepath + "/malformed/sequence")) - - -class DaemonTests(unittest.TestCase): - def connect(self, port): - from socket import socket, AF_INET, SOCK_STREAM - sock = socket(AF_INET, SOCK_STREAM) - sock.connect(('127.0.0.1', port)) - # send a split index of -1 to shutdown the worker - sock.send(b"\xFF\xFF\xFF\xFF") - sock.close() - return True - - def do_termination_test(self, terminator): - from subprocess import Popen, PIPE - from errno import ECONNREFUSED - - # start daemon - daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") - daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) - - # read the port number - port = read_int(daemon.stdout) - - # daemon should accept connections - self.assertTrue(self.connect(port)) - - # request shutdown - terminator(daemon) - time.sleep(1) - - # daemon should no longer accept connections - try: - self.connect(port) - except EnvironmentError as exception: - self.assertEqual(exception.errno, ECONNREFUSED) - else: - self.fail("Expected EnvironmentError to be raised") - - def test_termination_stdin(self): - """Ensure that daemon and workers terminate when stdin is closed.""" - self.do_termination_test(lambda daemon: daemon.stdin.close()) - - def test_termination_sigterm(self): - """Ensure that daemon and workers terminate on SIGTERM.""" - from signal import SIGTERM - self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) - - -class WorkerTests(ReusedPySparkTestCase): - def test_cancel_task(self): - temp = tempfile.NamedTemporaryFile(delete=True) - temp.close() - path = temp.name - - def sleep(x): - import os - import time - with open(path, 'w') as f: - f.write("%d %d" % (os.getppid(), os.getpid())) - time.sleep(100) - - # start job in background thread - def run(): - try: - self.sc.parallelize(range(1), 1).foreach(sleep) - except 
Exception: - pass - import threading - t = threading.Thread(target=run) - t.daemon = True - t.start() - - daemon_pid, worker_pid = 0, 0 - while True: - if os.path.exists(path): - with open(path) as f: - data = f.read().split(' ') - daemon_pid, worker_pid = map(int, data) - break - time.sleep(0.1) - - # cancel jobs - self.sc.cancelAllJobs() - t.join() - - for i in range(50): - try: - os.kill(worker_pid, 0) - time.sleep(0.1) - except OSError: - break # worker was killed - else: - self.fail("worker has not been killed after 5 seconds") - - try: - os.kill(daemon_pid, 0) - except OSError: - self.fail("daemon had been killed") - - # run a normal job - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_exception(self): - def raise_exception(_): - raise Exception() - rdd = self.sc.parallelize(xrange(100), 1) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_jvm_exception(self): - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name, 1) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_accumulator_when_reuse_worker(self): - from pyspark.accumulators import INT_ACCUMULATOR_PARAM - acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) - self.assertEqual(sum(range(100)), acc1.value) - - acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) - self.assertEqual(sum(range(100)), acc2.value) - self.assertEqual(sum(range(100)), acc1.value) - - def test_reuse_worker_after_take(self): - rdd = self.sc.parallelize(xrange(100000), 1) - self.assertEqual(0, rdd.first()) - - def count(): - try: - rdd.count() - except Exception: - pass - - t = threading.Thread(target=count) - t.daemon = True - t.start() - t.join(5) - self.assertTrue(not t.isAlive()) - self.assertEqual(100000, rdd.count()) - - def test_with_different_versions_of_python(self): - rdd = self.sc.parallelize(range(10)) - rdd.count() - version = self.sc.pythonVer - self.sc.pythonVer = "2.0" - try: - with QuietTest(self.sc): - self.assertRaises(Py4JJavaError, lambda: rdd.count()) - finally: - self.sc.pythonVer = version - - -class SparkSubmitTests(unittest.TestCase): - - def setUp(self): - self.programDir = tempfile.mkdtemp() - tmp_dir = tempfile.gettempdir() - self.sparkSubmit = [ - os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), - "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - ] - - def tearDown(self): - shutil.rmtree(self.programDir) - - def createTempFile(self, name, content, dir=None): - """ - Create a temp file with the given name and content and return its path. - Strips leading spaces from content up to the first '|' in each line. 
- """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name) - else: - os.makedirs(os.path.join(self.programDir, dir)) - path = os.path.join(self.programDir, dir, name) - with open(path, "w") as f: - f.write(content) - return path - - def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): - """ - Create a zip archive containing a file with the given content and return its path. - Strips leading spaces from content up to the first '|' in each line. - """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name + ext) - else: - path = os.path.join(self.programDir, dir, zip_name + ext) - zip = zipfile.ZipFile(path, 'w') - zip.writestr(name, content) - zip.close() - return path - - def create_spark_package(self, artifact_name): - group_id, artifact_id, version = artifact_name.split(":") - self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" - | - | - | 4.0.0 - | %s - | %s - | %s - | - """ % (group_id, artifact_id, version)).lstrip(), - os.path.join(group_id, artifact_id, version)) - self.createFileInZip("%s.py" % artifact_id, """ - |def myfunc(x): - | return x + 1 - """, ".jar", os.path.join(group_id, artifact_id, version), - "%s-%s" % (artifact_id, version)) - - def test_single_script(self): - """Submit and test a single script file""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_script_with_local_functions(self): - """Submit and test a single script file calling a global function""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 3 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out.decode('utf-8')) - - def test_module_dependency(self): - """Submit and test a script with a dependency on another module""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_module_dependency_on_cluster(self): - """Submit and test a script with a dependency on another module on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", - "local-cluster[1,1,1024]", script], - 
stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency(self): - """Submit and test a script with a dependency on a Spark Package""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency_on_cluster(self): - """Submit and test a script with a dependency on a Spark Package on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", - script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_single_script_on_cluster(self): - """Submit and test a single script on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 2 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - # this will fail if you have different spark.executor.memory - # in conf/spark-defaults.conf - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_user_configuration(self): - """Make sure user configuration is respected (SPARK-19307)""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkConf, SparkContext - | - |conf = SparkConf().set("spark.test_config", "1") - |sc = SparkContext(conf = conf) - |try: - | if sc._conf.get("spark.test_config") != "1": - | raise Exception("Cannot find spark.test_config in SparkContext's conf.") - |finally: - | sc.stop() - """) - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local", script], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) - - -class ContextTests(unittest.TestCase): - - def test_failed_sparkcontext_creation(self): - # Regression test for SPARK-1550 - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - - def test_get_or_create(self): - with SparkContext.getOrCreate() as sc: - self.assertTrue(SparkContext.getOrCreate() is sc) - - def test_parallelize_eager_cleanup(self): - with SparkContext() as sc: - temp_files = os.listdir(sc._temp_dir) - rdd = sc.parallelize([0, 1, 2]) - post_parallalize_temp_files = os.listdir(sc._temp_dir) - self.assertEqual(temp_files, post_parallalize_temp_files) - - def test_set_conf(self): - # This is for an internal use case. 
When there is an existing SparkContext, - # SparkSession's builder needs to set configs into SparkContext's conf. - sc = SparkContext() - sc._conf.set("spark.test.SPARK16224", "SPARK16224") - self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") - sc.stop() - - def test_stop(self): - sc = SparkContext() - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_exception(self): - try: - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - raise Exception() - except: - pass - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_stop(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_progress_api(self): - with SparkContext() as sc: - sc.setJobGroup('test_progress_api', '', True) - rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) - - def run(): - try: - rdd.count() - except Exception: - pass - t = threading.Thread(target=run) - t.daemon = True - t.start() - # wait for scheduler to start - time.sleep(1) - - tracker = sc.statusTracker() - jobIds = tracker.getJobIdsForGroup('test_progress_api') - self.assertEqual(1, len(jobIds)) - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual(1, len(job.stageIds)) - stage = tracker.getStageInfo(job.stageIds[0]) - self.assertEqual(rdd.getNumPartitions(), stage.numTasks) - - sc.cancelAllJobs() - t.join() - # wait for event listener to update the status - time.sleep(1) - - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual('FAILED', job.status) - self.assertEqual([], tracker.getActiveJobsIds()) - self.assertEqual([], tracker.getActiveStageIds()) - - sc.stop() - - def test_startTime(self): - with SparkContext() as sc: - self.assertGreater(sc.startTime, 0) - - -class ConfTests(unittest.TestCase): - def test_memory_conf(self): - memoryList = ["1T", "1G", "1M", "1024K"] - for memory in memoryList: - sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) - l = list(range(1024)) - random.shuffle(l) - rdd = sc.parallelize(l, 4) - self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class KeywordOnlyTests(unittest.TestCase): - class Wrapped(object): - @keyword_only - def set(self, x=None, y=None): - if "x" in self._input_kwargs: - self._x = self._input_kwargs["x"] - if "y" in self._input_kwargs: - self._y = self._input_kwargs["y"] - return x, y - - def test_keywords(self): - w = self.Wrapped() - x, y = w.set(y=1) - self.assertEqual(y, 1) - self.assertEqual(y, w._y) - self.assertIsNone(x) - self.assertFalse(hasattr(w, "_x")) - - def test_non_keywords(self): - w = self.Wrapped() - self.assertRaises(TypeError, lambda: w.set(0, y=1)) - - def test_kwarg_ownership(self): - # test _input_kwargs is owned by each class instance and not a shared static variable - class Setter(object): - @keyword_only - def set(self, x=None, other=None, other_x=None): - if "other" in self._input_kwargs: - self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) - self._x = self._input_kwargs["x"] - - a = Setter() - b = Setter() - a.set(x=1, other=b, other_x=2) - self.assertEqual(a._x, 1) - self.assertEqual(b._x, 2) - - -class 
UtilTests(PySparkTestCase): - def test_py4j_exception_message(self): - from pyspark.util import _exception_message - - with self.assertRaises(Py4JJavaError) as context: - # This attempts java.lang.String(null) which throws an NPE. - self.sc._jvm.java.lang.String(None) - - self.assertTrue('NullPointerException' in _exception_message(context.exception)) - - def test_parsing_version_string(self): - from pyspark.util import VersionUtils - self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): - - """General PySpark tests that depend on scipy """ - - def test_serialize(self): - from scipy.special import gammaln - x = range(1, 5) - expected = list(map(gammaln, x)) - observed = self.sc.parallelize(x).map(gammaln).collect() - self.assertEqual(expected, observed) - - -@unittest.skipIf(not _have_numpy, "NumPy not installed") -class NumPyTests(PySparkTestCase): - - """General PySpark tests that depend on numpy """ - - def test_statcounter_array(self): - x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) - s = x.stats() - self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) - self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) - - stats_dict = s.asDict() - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) - - stats_sample_dict = s.asDict(sample=True) - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) - self.assertSequenceEqual( - [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) - self.assertSequenceEqual( - [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) - - -if __name__ == "__main__": - from pyspark.tests import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) diff --git a/python/pyspark/tests/__init__.py b/python/pyspark/tests/__init__.py new file mode 100644 index 0000000000000..12bdf0d0175b6 --- /dev/null +++ b/python/pyspark/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py new file mode 100644 index 0000000000000..92bcb11561307 --- /dev/null +++ b/python/pyspark/tests/test_appsubmit.py @@ -0,0 +1,248 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import re +import shutil +import subprocess +import tempfile +import unittest +import zipfile + + +class SparkSubmitTests(unittest.TestCase): + + def setUp(self): + self.programDir = tempfile.mkdtemp() + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] + + def tearDown(self): + shutil.rmtree(self.programDir) + + def createTempFile(self, name, content, dir=None): + """ + Create a temp file with the given name and content and return its path. + Strips leading spaces from content up to the first '|' in each line. + """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + if dir is None: + path = os.path.join(self.programDir, name) + else: + os.makedirs(os.path.join(self.programDir, dir)) + path = os.path.join(self.programDir, dir, name) + with open(path, "w") as f: + f.write(content) + return path + + def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): + """ + Create a zip archive containing a file with the given content and return its path. + Strips leading spaces from content up to the first '|' in each line. 
+ """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + if dir is None: + path = os.path.join(self.programDir, name + ext) + else: + path = os.path.join(self.programDir, dir, zip_name + ext) + zip = zipfile.ZipFile(path, 'w') + zip.writestr(name, content) + zip.close() + return path + + def create_spark_package(self, artifact_name): + group_id, artifact_id, version = artifact_name.split(":") + self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" + | + | + | 4.0.0 + | %s + | %s + | %s + | + """ % (group_id, artifact_id, version)).lstrip(), + os.path.join(group_id, artifact_id, version)) + self.createFileInZip("%s.py" % artifact_id, """ + |def myfunc(x): + | return x + 1 + """, ".jar", os.path.join(group_id, artifact_id, version), + "%s-%s" % (artifact_id, version)) + + def test_single_script(self): + """Submit and test a single script file""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) + """) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) + + def test_script_with_local_functions(self): + """Submit and test a single script file calling a global function""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 3 + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) + """) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[3, 6, 9]", out.decode('utf-8')) + + def test_module_dependency(self): + """Submit and test a script with a dependency on another module""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_module_dependency_on_cluster(self): + """Submit and test a script with a dependency on another module on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", + "local-cluster[1,1,1024]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency(self): + """Submit and test a script with a dependency on a Spark Package""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", 
"--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency_on_cluster(self): + """Submit and test a script with a dependency on a Spark Package on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_single_script_on_cluster(self): + """Submit and test a single script on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 2 + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) + """) + # this will fail if you have different spark.executor.memory + # in conf/spark-defaults.conf + proc = subprocess.Popen( + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) + + def test_user_configuration(self): + """Make sure user configuration is respected (SPARK-19307)""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkConf, SparkContext + | + |conf = SparkConf().set("spark.test_config", "1") + |sc = SparkContext(conf = conf) + |try: + | if sc._conf.get("spark.test_config") != "1": + | raise Exception("Cannot find spark.test_config in SparkContext's conf.") + |finally: + | sc.stop() + """) + proc = subprocess.Popen( + self.sparkSubmit + ["--master", "local", script], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) + + +if __name__ == "__main__": + from pyspark.tests.test_appsubmit import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/tests/test_broadcast.py similarity index 91% rename from python/pyspark/test_broadcast.py rename to python/pyspark/tests/test_broadcast.py index a00329c18ad8f..a98626e8f4bc9 100644 --- a/python/pyspark/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -14,20 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - import os import random import tempfile import unittest -try: - import xmlrunner -except ImportError: - xmlrunner = None - -from pyspark.broadcast import Broadcast -from pyspark.conf import SparkConf -from pyspark.context import SparkContext +from pyspark import SparkConf, SparkContext from pyspark.java_gateway import launch_gateway from pyspark.serializers import ChunkedStream @@ -118,9 +110,13 @@ def random_bytes(n): for buffer_length in [1, 2, 5, 8192]: self._test_chunked_stream(random_bytes(data_length), buffer_length) + if __name__ == '__main__': - from pyspark.test_broadcast import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) + from pyspark.tests.test_broadcast import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py new file mode 100644 index 0000000000000..f5a9accc3fe6e --- /dev/null +++ b/python/pyspark/tests/test_conf.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import unittest + +from pyspark import SparkContext, SparkConf + + +class ConfTests(unittest.TestCase): + def test_memory_conf(self): + memoryList = ["1T", "1G", "1M", "1024K"] + for memory in memoryList: + sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) + l = list(range(1024)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_conf import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py new file mode 100644 index 0000000000000..201baf420354d --- /dev/null +++ b/python/pyspark/tests/test_context.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import tempfile +import threading +import time +import unittest + +from pyspark import SparkFiles, SparkContext +from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME + + +class CheckpointTests(ReusedPySparkTestCase): + + def setUp(self): + self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) + + def tearDown(self): + shutil.rmtree(self.checkpointDir.name) + + def test_basic_checkpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual("file:" + self.checkpointDir.name, + os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) + + def test_checkpoint_and_restore(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) + self.assertEqual([1, 2, 3, 4], recovered.collect()) + + +class LocalCheckpointTests(ReusedPySparkTestCase): + + def test_basic_localcheckpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) + + flatMappedRDD.localCheckpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + + +class AddFileTests(PySparkTestCase): + + def test_add_py_file(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this job fails due to `userlibrary` not being on the Python path: + # disable logging in log4j temporarily + def func(x): + from userlibrary import UserClass + return UserClass().hello() + with QuietTest(self.sc): + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) + + # Add the file, so the job should now succeed: + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + res = self.sc.parallelize(range(2)).map(func).first() + self.assertEqual("Hello World!", res) + + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + 
self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEqual("Hello World!\n", test_file.readline()) + + def test_add_file_recursively_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello") + self.sc.addFile(path, True) + download_path = SparkFiles.get("hello") + self.assertNotEqual(path, download_path) + with open(download_path + "/hello.txt") as test_file: + self.assertEqual("Hello World!\n", test_file.readline()) + with open(download_path + "/sub_hello/sub_hello.txt") as test_file: + self.assertEqual("Sub Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + + def test_add_egg_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlib import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") + self.sc.addPyFile(path) + from userlib import UserClass + self.assertEqual("Hello World from inside a package!", UserClass().hello()) + + def test_overwrite_system_module(self): + self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) + + import SimpleHTTPServer + self.assertEqual("My Server", SimpleHTTPServer.__name__) + + def func(x): + import SimpleHTTPServer + return SimpleHTTPServer.__name__ + + self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) + + +class ContextTests(unittest.TestCase): + + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) + + def test_get_or_create(self): + with SparkContext.getOrCreate() as sc: + self.assertTrue(SparkContext.getOrCreate() is sc) + + def test_parallelize_eager_cleanup(self): + with SparkContext() as sc: + temp_files = os.listdir(sc._temp_dir) + rdd = sc.parallelize([0, 1, 2]) + post_parallalize_temp_files = os.listdir(sc._temp_dir) + self.assertEqual(temp_files, post_parallalize_temp_files) + + def test_set_conf(self): + # This is for an internal use case. When there is an existing SparkContext, + # SparkSession's builder needs to set configs into SparkContext's conf. 
+ sc = SparkContext() + sc._conf.set("spark.test.SPARK16224", "SPARK16224") + self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") + sc.stop() + + def test_stop(self): + sc = SparkContext() + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_exception(self): + try: + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + raise Exception() + except: + pass + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_stop(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + + def run(): + try: + rdd.count() + except Exception: + pass + t = threading.Thread(target=run) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + + def test_startTime(self): + with SparkContext() as sc: + self.assertGreater(sc.startTime, 0) + + +if __name__ == "__main__": + from pyspark.tests.test_context import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py new file mode 100644 index 0000000000000..fccd74fff1516 --- /dev/null +++ b/python/pyspark/tests/test_daemon.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import sys +import time +import unittest + +from pyspark.serializers import read_int + + +class DaemonTests(unittest.TestCase): + def connect(self, port): + from socket import socket, AF_INET, SOCK_STREAM + sock = socket(AF_INET, SOCK_STREAM) + sock.connect(('127.0.0.1', port)) + # send a split index of -1 to shutdown the worker + sock.send(b"\xFF\xFF\xFF\xFF") + sock.close() + return True + + def do_termination_test(self, terminator): + from subprocess import Popen, PIPE + from errno import ECONNREFUSED + + # start daemon + daemon_path = os.path.join(os.path.dirname(__file__), "..", "daemon.py") + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) + + # read the port number + port = read_int(daemon.stdout) + + # daemon should accept connections + self.assertTrue(self.connect(port)) + + # request shutdown + terminator(daemon) + time.sleep(1) + + # daemon should no longer accept connections + try: + self.connect(port) + except EnvironmentError as exception: + self.assertEqual(exception.errno, ECONNREFUSED) + else: + self.fail("Expected EnvironmentError to be raised") + + def test_termination_stdin(self): + """Ensure that daemon and workers terminate when stdin is closed.""" + self.do_termination_test(lambda daemon: daemon.stdin.close()) + + def test_termination_sigterm(self): + """Ensure that daemon and workers terminate on SIGTERM.""" + from signal import SIGTERM + self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) + + +if __name__ == "__main__": + from pyspark.tests.test_daemon import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py new file mode 100644 index 0000000000000..e97e695f8b20d --- /dev/null +++ b/python/pyspark/tests/test_join.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from pyspark.testing.utils import ReusedPySparkTestCase + + +class JoinTests(ReusedPySparkTestCase): + + def test_narrow_dependency_in_join(self): + rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) + parted = rdd.partitionBy(2) + self.assertEqual(2, parted.union(parted).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) + + tracker = self.sc.statusTracker() + + self.sc.setJobGroup("test1", "test", True) + d = sorted(parted.join(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test1")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test2", "test", True) + d = sorted(parted.join(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test2")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test3", "test", True) + d = sorted(parted.cogroup(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test3")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test4", "test", True) + d = sorted(parted.cogroup(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test4")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_join import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py new file mode 100644 index 0000000000000..56cbcff01657c --- /dev/null +++ b/python/pyspark/tests/test_profiler.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import sys +import tempfile +import unittest + +from pyspark import SparkConf, SparkContext, BasicProfiler +from pyspark.testing.utils import PySparkTestCase + +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO + + +class ProfilerTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext('local[4]', class_name, conf=conf) + + def test_profiler(self): + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue("heavy_foo" in func_names) + + old_stdout = sys.stdout + sys.stdout = io = StringIO() + self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 18): + x = 1 + + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + + +class ProfilerTests2(unittest.TestCase): + def test_profiler_disabled(self): + sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) + try: + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.show_profiles()) + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.dump_profiles("/tmp/abc")) + finally: + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_profiler import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py new file mode 100644 index 0000000000000..b2a544b8de78a --- /dev/null +++ b/python/pyspark/tests/test_rdd.py @@ -0,0 +1,739 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import hashlib +import os +import random +import sys +import tempfile +from glob import glob + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, RDD +from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, PickleSerializer,\ + MarshalSerializer, UTF8Deserializer, NoOpSerializer +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class RDDTests(ReusedPySparkTestCase): + + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + + def test_id(self): + rdd = self.sc.parallelize(range(10)) + id = rdd.id() + self.assertEqual(id, rdd.id()) + rdd2 = rdd.map(str).filter(bool) + id2 = rdd2.id() + self.assertEqual(id + 1, id2) + self.assertEqual(id2, rdd2.id()) + + def test_empty_rdd(self): + rdd = self.sc.emptyRDD() + self.assertTrue(rdd.isEmpty()) + + def test_sum(self): + self.assertEqual(0, self.sc.emptyRDD().sum()) + self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + + def test_to_localiterator(self): + from time import sleep + rdd = self.sc.parallelize([1, 2, 3]) + it = rdd.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it)) + + rdd2 = rdd.repartition(1000) + it2 = rdd2.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it2)) + + def test_save_as_textfile_with_unicode(self): + # Regression test for SPARK-970 + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) + + def test_save_as_textfile_with_utf8(self): + x = u"\u00A1Hola, mundo!" 
+ data = self.sc.parallelize([x.encode("utf-8")]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) + + def test_transforming_cartesian_result(self): + # Regression test for SPARK-1034 + rdd1 = self.sc.parallelize([1, 2]) + rdd2 = self.sc.parallelize([3, 4]) + cart = rdd1.cartesian(rdd2) + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() + + def test_transforming_pickle_file(self): + # Regression test for SPARK-2601 + data = self.sc.parallelize([u"Hello", u"World!"]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsPickleFile(tempFile.name) + pickled_file = self.sc.pickleFile(tempFile.name) + pickled_file.map(lambda x: x).collect() + + def test_cartesian_on_textfile(self): + # Regression test for + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + a = self.sc.textFile(path) + result = a.cartesian(a).collect() + (x, y) = result[0] + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) + + def test_cartesian_chaining(self): + # Tests for SPARK-16589 + rdd = self.sc.parallelize(range(10), 2) + self.assertSetEqual( + set(rdd.cartesian(rdd).cartesian(rdd).collect()), + set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.cartesian(rdd)).collect()), + set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.zip(rdd)).collect()), + set([(x, (y, y)) for x in range(10) for y in range(10)]) + ) + + def test_zip_chaining(self): + # Tests for SPARK-21985 + rdd = self.sc.parallelize('abc', 2) + self.assertSetEqual( + set(rdd.zip(rdd).zip(rdd).collect()), + set([((x, x), x) for x in 'abc']) + ) + self.assertSetEqual( + set(rdd.zip(rdd.zip(rdd)).collect()), + set([(x, (x, x)) for x in 'abc']) + ) + + def test_deleting_input_files(self): + # Regression test for SPARK-1025 + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + def test_sampling_default_seed(self): + # Test for SPARK-3995 (default seed setting) + data = self.sc.parallelize(xrange(1000), 1) + subset = data.takeSample(False, 10) + self.assertEqual(len(subset), 10) + + def test_aggregate_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregate and treeAggregate to build dict + # representing a counter of ints + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + # Show that single or multiple partitions work + data1 = self.sc.range(10, numSlices=1) + data2 = self.sc.range(10, numSlices=2) + + def seqOp(x, y): + x[y] += 1 + return x + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) + counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) + counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + + ground_truth = 
defaultdict(int, dict((i, 1) for i in range(10))) + self.assertEqual(counts1, ground_truth) + self.assertEqual(counts2, ground_truth) + self.assertEqual(counts3, ground_truth) + self.assertEqual(counts4, ground_truth) + + def test_aggregate_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that + # contains lists of all values for each key in the original RDD + + # list(range(...)) for Python 3.x compatibility (can't use * operator + # on a range object) + # list(zip(...)) for Python 3.x compatibility (want to parallelize a + # collection, not a zip object) + tuples = list(zip(list(range(10))*2, [1]*20)) + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def seqOp(x, y): + x.append(y) + return x + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.aggregateByKey([], seqOp, comboOp).collect() + values2 = data2.aggregateByKey([], seqOp, comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + ground_truth = [(i, [1]*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_fold_mutable_zero_value(self): + # Test for SPARK-9021; uses fold to merge an RDD of dict counters into + # a single dict + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + counts1 = defaultdict(int, dict((i, 1) for i in range(10))) + counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) + counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) + counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) + all_counts = [counts1, counts2, counts3, counts4] + # Show that single or multiple partitions work + data1 = self.sc.parallelize(all_counts, 1) + data2 = self.sc.parallelize(all_counts, 2) + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + fold1 = data1.fold(defaultdict(int), comboOp) + fold2 = data2.fold(defaultdict(int), comboOp) + + ground_truth = defaultdict(int) + for counts in all_counts: + for key, val in counts.items(): + ground_truth[key] += val + self.assertEqual(fold1, ground_truth) + self.assertEqual(fold2, ground_truth) + + def test_fold_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains + # lists of all values for each key in the original RDD + + tuples = [(i, range(i)) for i in range(10)]*2 + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.foldByKey([], comboOp).collect() + values2 = data2.foldByKey([], comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + # list(range(...)) for Python 3.x compatibility + ground_truth = [(i, list(range(i))*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_aggregate_by_key(self): + data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + + def seqOp(x, y): + x.add(y) + return x + + def combOp(x, y): + x |= y + return x + + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) + self.assertEqual(3, len(sets)) + self.assertEqual(set([1]), sets[1]) + self.assertEqual(set([2]), sets[3]) + self.assertEqual(set([1, 3]), sets[5]) + + 
def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + + def test_namedtuple_in_rdd(self): + from collections import namedtuple + Person = namedtuple("Person", "id firstName lastName") + jon = Person(1, "Jon", "Doe") + jane = Person(2, "Jane", "Doe") + theDoes = self.sc.parallelize([jon, jane]) + self.assertEqual([jon, jane], theDoes.collect()) + + def test_large_broadcast(self): + N = 10000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 27MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + + def test_multiple_broadcasts(self): + N = 1 << 21 + b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM + r = list(range(1 << 15)) + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + def test_multithread_broadcast_pickle(self): + import threading + + b1 = self.sc.broadcast(list(range(3))) + b2 = self.sc.broadcast(list(range(3))) + + def f1(): + return b1.value + + def f2(): + return b2.value + + funcs_num_pickled = {f1: None, f2: None} + + def do_pickle(f, sc): + command = (f, None, sc.serializer, sc.serializer) + ser = CloudPickleSerializer() + ser.dumps(command) + + def process_vars(sc): + broadcast_vars = list(sc._pickled_broadcast_vars) + num_pickled = len(broadcast_vars) + sc._pickled_broadcast_vars.clear() + return num_pickled + + def run(f, sc): + do_pickle(f, sc) + funcs_num_pickled[f] = process_vars(sc) + + # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage + do_pickle(f1, self.sc) + + # run all for f2, should only add/count/clear b2 from worker thread local storage + t = threading.Thread(target=run, args=(f2, self.sc)) + t.start() + t.join() + + # count number of vars pickled in main thread, only b1 should be counted and cleared + funcs_num_pickled[f1] = process_vars(self.sc) + + self.assertEqual(funcs_num_pickled[f1], 1) + self.assertEqual(funcs_num_pickled[f2], 1) + self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) + + def test_large_closure(self): + N = 200000 + data = [float(i) for i in xrange(N)] + rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) + self.assertEqual(N, rdd.first()) + # regression test for SPARK-6886 + self.assertEqual(1, 
rdd.map(lambda x: (x, 1)).groupByKey().count()) + + def test_zip_with_different_serializers(self): + a = self.sc.parallelize(range(5)) + b = self.sc.parallelize(range(100, 105)) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + b = b._reserialize(MarshalSerializer()) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + # regression test for SPARK-4841 + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + t = self.sc.textFile(path) + cnt = t.count() + self.assertEqual(cnt, t.zip(t).count()) + rdd = t.map(str) + self.assertEqual(cnt, t.zip(rdd).count()) + # regression test for bug in _reserializer() + self.assertEqual(cnt, t.zip(rdd).count()) + + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + + def test_zip_with_different_number_of_items(self): + a = self.sc.parallelize(range(5), 2) + # different number of partitions + b = self.sc.parallelize(range(100, 106), 3) + self.assertRaises(ValueError, lambda: a.zip(b)) + with QuietTest(self.sc): + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) + + def test_count_approx_distinct(self): + rdd = self.sc.parallelize(xrange(1000)) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) + + rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) + self.assertTrue(18 < rdd.countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) + + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) + + def test_histogram(self): + # empty + rdd = self.sc.parallelize([]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertRaises(ValueError, lambda: rdd.histogram(1)) + + # out of range + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) + + # in range with one bucket + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) + + # in range with one bucket exact match + self.assertEqual([4], rdd.histogram([1, 4])[1]) + + # out of range with two buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) + + # out of range with two uneven buckets + rdd = 
self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + + # in range with two buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two bucket and None + rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two uneven buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) + + # mixed range with two uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) + + # mixed range with four uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # mixed range with uneven buckets and NaN + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, + 199.0, 200.0, 200.1, None, float('nan')]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # out of range with infinite buckets + rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + + # invalid buckets + self.assertRaises(ValueError, lambda: rdd.histogram([])) + self.assertRaises(ValueError, lambda: rdd.histogram([1])) + self.assertRaises(ValueError, lambda: rdd.histogram(0)) + self.assertRaises(TypeError, lambda: rdd.histogram({})) + + # without buckets + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) + + # without buckets single element + rdd = self.sc.parallelize([1]) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) + + # without bucket no range + rdd = self.sc.parallelize([1] * 4) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) + + # without buckets basic two + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + + # without buckets with more requested than elements + rdd = self.sc.parallelize([1, 2]) + buckets = [1 + 0.2 * i for i in range(6)] + hist = [1, 0, 0, 0, 1] + self.assertEqual((buckets, hist), rdd.histogram(5)) + + # invalid RDDs + rdd = self.sc.parallelize([1, float('inf')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + rdd = self.sc.parallelize([float('nan')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + + # string + rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) + self.assertRaises(TypeError, lambda: rdd.histogram(2)) + + def test_repartitionAndSortWithinPartitions_asc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + + def test_repartitionAndSortWithinPartitions_desc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) + self.assertEqual(partitions[1], 
[(3, 8), (3, 8), (1, 3)]) + + def test_repartition_no_skewed(self): + num_partitions = 20 + a = self.sc.parallelize(range(int(1000)), 2) + l = a.repartition(num_partitions).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + l = a.coalesce(num_partitions, True).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + + def test_repartition_on_textfile(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + rdd = self.sc.textFile(path) + result = rdd.repartition(1).collect() + self.assertEqual(u"Hello World!", result[0]) + + def test_distinct(self): + rdd = self.sc.parallelize((1, 2, 3)*10, 10) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) + result = rdd.distinct(5) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) + + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "1m") + N = 200001 + kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda kv: kv[0] == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) + + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + + def test_null_in_rdd(self): + jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) + rdd = RDD(jrdd, self.sc, UTF8Deserializer()) + self.assertEqual([u"a", None, u"b"], rdd.collect()) + rdd = RDD(jrdd, self.sc, NoOpSerializer()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) + + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # Regression test for SPARK-6294 + def test_take_on_jrdd(self): + rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) + rdd._jrdd.first() + + def 
test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + + def test_pipe_functions(self): + data = ['1', '2', '3'] + rdd = self.sc.parallelize(data) + with QuietTest(self.sc): + self.assertEqual([], rdd.pipe('cc').collect()) + self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) + result = rdd.pipe('cat').collect() + result.sort() + for x, y in zip(data, result): + self.assertEqual(x, y) + self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) + self.assertEqual([], rdd.pipe('grep 4').collect()) + + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + + def test_stopiteration_in_user_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_rdd import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py new file mode 100644 index 0000000000000..e45f5b371f461 --- /dev/null +++ b/python/pyspark/tests/test_readwrite.py @@ -0,0 +1,499 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import sys +import tempfile +import unittest +from array import array + +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME + + +class InputFormatTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", + "org.apache.hadoop.io.DoubleWritable", + "org.apache.hadoop.io.Text").collect()) + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.assertEqual(doubles, ed) + + bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) + ebs = [(1, bytearray('aa', 'utf-8')), + (1, bytearray('aa', 'utf-8')), + (2, bytearray('aa', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (3, bytearray('cc', 'utf-8'))] + self.assertEqual(bytes, ebs) + + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", + "org.apache.hadoop.io.Text", + "org.apache.hadoop.io.Text").collect()) + et = [(u'1', u'aa'), + (u'1', u'aa'), + (u'2', u'aa'), + (u'2', u'bb'), + (u'2', u'bb'), + (u'3', u'cc')] + self.assertEqual(text, et) + + bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.assertEqual(bools, eb) + + nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.assertEqual(nulls, en) + + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + for v in maps: + self.assertTrue(v in em) + + # arrays get pickled to tuples by default + tuples = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable").collect()) + et = [(1, ()), + (2, (3.0, 4.0, 5.0)), + (3, (4.0, 5.0, 6.0))] + self.assertEqual(tuples, et) + + # with custom 
converters, primitive arrays can stay as arrays + arrays = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + ea = [(1, array('d')), + (2, array('d', [3.0, 4.0, 5.0])), + (3, array('d', [4.0, 5.0, 6.0]))] + self.assertEqual(arrays, ea) + + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable").collect()) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) + + unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + ).collect()) + self.assertEqual(unbatched_clazz, ec) + + def test_oldhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=oldconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=newconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newolderror(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.sequenceFile( + basepath + "/sftestdata/sfint/", + 
"org.apache.hadoop.io.NotValidWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + maps = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + keyConverter="org.apache.spark.api.python.TestInputKeyConverter", + valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) + em = [(u'\x01', []), + (u'\x01', [3.0]), + (u'\x02', [1.0]), + (u'\x02', [1.0]), + (u'\x03', [2.0])] + self.assertEqual(maps, em) + + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(list(range(100)), result) + + +class OutputFormatTests(ReusedPySparkTestCase): + + def setUp(self): + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + + def tearDown(self): + shutil.rmtree(self.tempdir.name, ignore_errors=True) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") + ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) + self.assertEqual(ints, ei) + + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") + doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) + self.assertEqual(doubles, ed) + + ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] + self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") + bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) + self.assertEqual(bytes, ebs) + + et = [(u'1', u'aa'), + (u'2', u'bb'), + (u'3', u'cc')] + self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") + text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) + self.assertEqual(text, et) + + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") + bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) + self.assertEqual(bools, eb) + + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") + nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) + 
self.assertEqual(nulls, en) + + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") + maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() + for v in maps: + self.assertTrue(v, em) + + def test_oldhadoop(self): + basepath = self.tempdir.name + dict_data = [(1, {}), + (1, {"row1": 1.0}), + (2, {"row2": 2.0})] + self.sc.parallelize(dict_data).saveAsHadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable") + result = self.sc.hadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() + for v in result: + self.assertTrue(v, dict_data) + + conf = { + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" + } + self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} + result = self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + conf=input_conf).collect() + for v in result: + self.assertTrue(v, dict_data) + + def test_newhadoop(self): + basepath = self.tempdir.name + data = [(1, ""), + (1, "a"), + (2, "bcdf")] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + self.assertEqual(result, data) + + conf = { + "mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" + } + self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=input_conf).collect()) + self.assertEqual(new_dataset, data) + + @unittest.skipIf(sys.version >= "3", "serialize of array") + def test_newhadoop_with_array(self): + basepath = self.tempdir.name + # use custom ArrayWritable types and converters to handle arrays + array_data = [(1, array('d')), + (1, array('d', [1.0, 2.0, 3.0])), + (2, array('d', [3.0, 4.0, 5.0]))] + self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + result = 
sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + self.assertEqual(result, array_data) + + conf = { + "mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" + } + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( + conf, + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", + conf=input_conf).collect()) + self.assertEqual(new_dataset, array_data) + + def test_newolderror(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/newolderror/saveAsHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/newolderror/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/badinputs/saveAsHadoopFile/", + "org.apache.hadoop.mapred.NotValidOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/badinputs/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + data = [(1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/converters/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", + valueConverter="org.apache.spark.api.python.TestOutputValueConverter") + converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) + expected = [(u'1', 3.0), + (u'2', 1.0), + (u'3', 2.0)] + self.assertEqual(converted, expected) + + def test_reserialization(self): + basepath = self.tempdir.name + x = range(1, 5) + y = range(1001, 1005) + data = list(zip(x, y)) + rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) + rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") + result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) + self.assertEqual(result1, data) + + rdd.saveAsHadoopFile( + basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") + result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) + self.assertEqual(result2, data) + + rdd.saveAsNewAPIHadoopFile( + basepath + 
"/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) + self.assertEqual(result3, data) + + conf4 = { + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} + rdd.saveAsHadoopDataset(conf4) + result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) + self.assertEqual(result4, data) + + conf5 = {"mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" + } + rdd.saveAsNewAPIHadoopDataset(conf5) + result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) + self.assertEqual(result5, data) + + def test_malformed_RDD(self): + basepath = self.tempdir.name + # non-batch-serialized RDD[[(K, V)]] should be rejected + data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] + rdd = self.sc.parallelize(data, len(data)) + self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( + basepath + "/malformed/sequence")) + + +if __name__ == "__main__": + from pyspark.tests.test_readwrite import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py new file mode 100644 index 0000000000000..bce94062c8af7 --- /dev/null +++ b/python/pyspark/tests/test_serializers.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import math +import sys +import unittest + +from pyspark import serializers +from pyspark.serializers import * +from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \ + AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \ + FlattenedValuesSerializer, CartesianDeserializer +from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \ + have_numpy, have_scipy + + +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from pickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEqual(p1, p2) + + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + + def test_itemgetter(self): + from operator import itemgetter + ser = CloudPickleSerializer() + d = range(10) + getter = itemgetter(1) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + getter = itemgetter(0, 3) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + def test_function_module_name(self): + ser = CloudPickleSerializer() + func = lambda x: x + func2 = ser.loads(ser.dumps(func)) + self.assertEqual(func.__module__, func2.__module__) + + def test_attrgetter(self): + from operator import attrgetter + ser = CloudPickleSerializer() + + class C(object): + def __getattr__(self, item): + return item + d = C() + getter = attrgetter("a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("a", "b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + d.e = C() + getter = attrgetter("e.a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("e.a", "e.b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + # Regression test for SPARK-3415 + def test_pickling_file_handles(self): + # to be corrected with SPARK-11160 + try: + import xmlrunner + except ImportError: + ser = CloudPickleSerializer() + out1 = sys.stderr + out2 = ser.loads(ser.dumps(out1)) + self.assertEqual(out1, out2) + + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.__code__.co_names) + ser.dumps(foo) + + def test_compressed_serializer(self): + ser = CompressedSerializer(PickleSerializer()) + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO + io = StringIO() + ser.dump_stream(["abc", u"123", range(5)], io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) + ser.dump_stream(range(1000), io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) 
+ hash(FlattenedValuesSerializer(PickleSerializer())) + + +@unittest.skipIf(not have_scipy, "SciPy not installed") +class SciPyTests(PySparkTestCase): + + """General PySpark tests that depend on scipy """ + + def test_serialize(self): + from scipy.special import gammaln + + x = range(1, 5) + expected = list(map(gammaln, x)) + observed = self.sc.parallelize(x).map(gammaln).collect() + self.assertEqual(expected, observed) + + +@unittest.skipIf(not have_numpy, "NumPy not installed") +class NumPyTests(PySparkTestCase): + + """General PySpark tests that depend on numpy """ + + def test_statcounter_array(self): + import numpy as np + + x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) + s = x.stats() + self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) + self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) + + stats_dict = s.asDict() + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) + + stats_sample_dict = s.asDict(sample=True) + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) + self.assertSequenceEqual( + [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) + self.assertSequenceEqual( + [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) + + +class SerializersTest(unittest.TestCase): + + def test_chunked_stream(self): + original_bytes = bytearray(range(100)) + for data_length in [1, 10, 100]: + for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: + dest = ByteArrayOutput() + stream_out = serializers.ChunkedStream(dest, buffer_length) + stream_out.write(original_bytes[:data_length]) + stream_out.close() + num_chunks = int(math.ceil(float(data_length) / buffer_length)) + # length for each chunk, and a final -1 at the very end + exp_size = (num_chunks + 1) * 4 + data_length + self.assertEqual(len(dest.buffer), exp_size) + dest_pos = 0 + data_pos = 0 + for chunk_idx in range(num_chunks): + chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) + if chunk_idx == num_chunks - 1: + exp_length = data_length % buffer_length + if exp_length == 0: + exp_length = buffer_length + else: + exp_length = buffer_length + self.assertEqual(chunk_length, exp_length) + dest_pos += 4 + dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] + orig_chunk = original_bytes[data_pos:data_pos + chunk_length] + self.assertEqual(dest_chunk, orig_chunk) + dest_pos += chunk_length + data_pos += chunk_length + # ends with a -1 + self.assertEqual(dest.buffer[-4:], write_int(-1)) + + +if __name__ == "__main__": + from pyspark.tests.test_serializers import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + 
unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py new file mode 100644 index 0000000000000..0489426061b75 --- /dev/null +++ b/python/pyspark/tests/test_shuffle.py @@ -0,0 +1,181 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter + +if sys.version_info[0] >= 3: + xrange = range + + +class MergerTests(unittest.TestCase): + + def setUp(self): + self.N = 1 << 12 + self.l = [i for i in xrange(self.N)] + self.data = list(zip(self.l, self.l)) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) + + def test_small_dataset(self): + m = ExternalMerger(self.agg, 1000) + m.mergeValues(self.data) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 1000) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + def test_medium_dataset(self): + m = ExternalMerger(self.agg, 20) + m.mergeValues(self.data) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N)) * 3) + + def test_huge_dataset(self): + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(len(v) for k, v in m.items()), + self.N * 10) + m._cleanup() + + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + def 
test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + + +class SorterTests(unittest.TestCase): + def test_in_memory_sort(self): + l = list(range(1024)) + random.shuffle(l) + sorter = ExternalSorter(1024) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + + def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit + l = list(range(1024)) + random.shuffle(l) + sorter = CustomizedSorter(1) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertGreater(shuffle.DiskBytesSpilled, 0) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + + def test_external_sort_in_rdd(self): + conf = SparkConf().set("spark.python.worker.memory", "1m") + sc = SparkContext(conf=conf) + l = list(range(10240)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_shuffle import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py new file mode 100644 index 0000000000000..b3a967440a9b2 --- /dev/null +++ b/python/pyspark/tests/test_taskcontext.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import time + +from pyspark import SparkContext, TaskContext, BarrierTaskContext +from pyspark.testing.utils import PySparkTestCase + + +class TaskContextTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + # Allow retries even though they are normally disabled in local mode + self.sc = SparkContext('local[4, 2]', class_name) + + def test_stage_id(self): + """Test the stage ids are available and incrementing as expected.""" + rdd = self.sc.parallelize(range(10)) + stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + # Test using the constructor directly rather than the get() + stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] + self.assertEqual(stage1 + 1, stage2) + self.assertEqual(stage1 + 2, stage3) + self.assertEqual(stage2 + 1, stage3) + + def test_partition_id(self): + """Test the partition id.""" + rdd1 = self.sc.parallelize(range(10), 1) + rdd2 = self.sc.parallelize(range(10), 2) + pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() + pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() + self.assertEqual(0, pids1[0]) + self.assertEqual(0, pids1[9]) + self.assertEqual(0, pids2[0]) + self.assertEqual(1, pids2[9]) + + def test_attempt_number(self): + """Verify the attempt numbers are correctly reported.""" + rdd = self.sc.parallelize(range(10)) + # Verify a simple job with no failures + attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() + map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) + + def fail_on_first(x): + """Fail on the first attempt so we get a positive attempt number""" + tc = TaskContext.get() + attempt_number = tc.attemptNumber() + partition_id = tc.partitionId() + attempt_id = tc.taskAttemptId() + if attempt_number == 0 and partition_id == 0: + raise Exception("Failing on first attempt") + else: + return [x, partition_id, attempt_number, attempt_id] + result = rdd.map(fail_on_first).collect() + # We should re-submit the first partition to it but other partitions should be attempt 0 + self.assertEqual([0, 0, 1], result[0][0:3]) + self.assertEqual([9, 3, 0], result[9][0:3]) + first_partition = filter(lambda x: x[1] == 0, result) + map(lambda x: self.assertEqual(1, x[2]), first_partition) + other_partitions = filter(lambda x: x[1] != 0, result) + map(lambda x: self.assertEqual(0, x[2]), other_partitions) + # The task attempt id should be different + self.assertTrue(result[0][3] != result[9][3]) + + def test_tc_on_driver(self): + """Verify that getting the TaskContext on the driver returns None.""" + tc = TaskContext.get() + self.assertTrue(tc is None) + + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] + 
self.assertEqual(prop1, value) + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + + def test_barrier(self): + """ + Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks + within a stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + tc.barrier() + return time.time() + + times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() + self.assertTrue(max(times) - min(times) < 1) + + def test_barrier_with_python_worker_reuse(self): + """ + Verify that BarrierTaskContext.barrier() with reused python worker. + """ + self.sc._conf.set("spark.python.work.reuse", "true") + rdd = self.sc.parallelize(range(4), 4) + # start a normal job first to start all worker + result = rdd.map(lambda x: x ** 2).collect() + self.assertEqual([0, 1, 4, 9], result) + # make sure `spark.python.work.reuse=true` + self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true") + + # worker will be reused in this barrier job + self.test_barrier() + + def test_barrier_infos(self): + """ + Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the + barrier stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() + .getTaskInfos()).collect() + self.assertTrue(len(taskInfos) == 4) + self.assertTrue(len(taskInfos[0]) == 4) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_taskcontext import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py new file mode 100644 index 0000000000000..11cda8fd2f5cd --- /dev/null +++ b/python/pyspark/tests/test_util.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
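The barrier tests above drive `BarrierTaskContext` through `rdd.barrier().mapPartitions(...)`. A minimal standalone sketch of the same pattern is shown below; the local master string, app name, and partition count are illustrative assumptions rather than values taken from the suite, and the master must offer at least as many cores as there are partitions so that every barrier task can be scheduled at once.

```
from pyspark import SparkContext, BarrierTaskContext

sc = SparkContext("local[4]", "barrier-sketch")  # hypothetical local context

def sync_then_report(iterator):
    tc = BarrierTaskContext.get()
    tc.barrier()                # global sync point across all tasks in the barrier stage
    infos = tc.getTaskInfos()   # one entry per task in the barrier stage
    yield (tc.partitionId(), len(infos), sum(iterator))

result = sc.parallelize(range(8), 4).barrier().mapPartitions(sync_then_report).collect()
print(sorted(result))           # four tuples, each seeing 4 task infos
sc.stop()
```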
+# +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import keyword_only +from pyspark.testing.utils import PySparkTestCase + + +class KeywordOnlyTests(unittest.TestCase): + class Wrapped(object): + @keyword_only + def set(self, x=None, y=None): + if "x" in self._input_kwargs: + self._x = self._input_kwargs["x"] + if "y" in self._input_kwargs: + self._y = self._input_kwargs["y"] + return x, y + + def test_keywords(self): + w = self.Wrapped() + x, y = w.set(y=1) + self.assertEqual(y, 1) + self.assertEqual(y, w._y) + self.assertIsNone(x) + self.assertFalse(hasattr(w, "_x")) + + def test_non_keywords(self): + w = self.Wrapped() + self.assertRaises(TypeError, lambda: w.set(0, y=1)) + + def test_kwarg_ownership(self): + # test _input_kwargs is owned by each class instance and not a shared static variable + class Setter(object): + @keyword_only + def set(self, x=None, other=None, other_x=None): + if "other" in self._input_kwargs: + self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) + self._x = self._input_kwargs["x"] + + a = Setter() + b = Setter() + a.set(x=1, other=b, other_x=2) + self.assertEqual(a._x, 1) + self.assertEqual(b._x, 2) + + +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + + +if __name__ == "__main__": + from pyspark.tests.test_util import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py new file mode 100644 index 0000000000000..a33b77d983419 --- /dev/null +++ b/python/pyspark/tests/test_worker.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
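The `KeywordOnlyTests` above depend on `@keyword_only` capturing the caller's keyword arguments in `self._input_kwargs` and rejecting positional arguments. A small sketch of that contract follows; the `Params` class and its `alpha`/`beta` parameters are made up for illustration.

```
from pyspark import keyword_only

class Params(object):
    @keyword_only
    def setParams(self, alpha=None, beta=None):
        # Only the keyword arguments the caller actually passed show up here.
        kwargs = self._input_kwargs
        if "alpha" in kwargs:
            self._alpha = kwargs["alpha"]
        if "beta" in kwargs:
            self._beta = kwargs["beta"]
        return kwargs

p = Params()
print(p.setParams(alpha=1.0))   # {'alpha': 1.0}; beta was not passed, so it is absent
# p.setParams(1.0)              # positional use raises TypeError, as test_non_keywords checks
```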
+# +import os +import sys +import tempfile +import threading +import time + +from py4j.protocol import Py4JJavaError + +from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class WorkerTests(ReusedPySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + + def sleep(x): + import os + import time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + with open(path) as f: + data = f.read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + # run a normal job + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(xrange(100), 1) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + + def test_reuse_worker_after_take(self): + rdd = self.sc.parallelize(xrange(100000), 1) + self.assertEqual(0, rdd.first()) + + def count(): + try: + rdd.count() + except Exception: + pass + + t = threading.Thread(target=count) + t.daemon = True + t.start() + t.join(5) + self.assertTrue(not t.isAlive()) + self.assertEqual(100000, rdd.count()) + + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = self.sc.pythonVer + self.sc.pythonVer = "2.0" + try: + with QuietTest(self.sc): + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + self.sc.pythonVer = version + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_worker import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + 
unittest.main(testRunner=testRunner, verbosity=2) From d4130ec1f3461dcc961eee9802005ba7a15212d1 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Nov 2018 17:20:49 +0800 Subject: [PATCH 047/145] [SPARK-26014][R] Deprecate R prior to version 3.4 in SparkR ## What changes were proposed in this pull request? This PR proposes to bump up the minimum versions of R from 3.1 to 3.4. R version. 3.1.x is too old. It's released 4.5 years ago. R 3.4.0 is released 1.5 years ago. Considering the timing for Spark 3.0, deprecating lower versions, bumping up R to 3.4 might be reasonable option. It should be good to deprecate and drop < R 3.4 support. ## How was this patch tested? Jenkins tests. Closes #23012 from HyukjinKwon/SPARK-26014. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- R/WINDOWS.md | 2 +- R/pkg/DESCRIPTION | 2 +- R/pkg/inst/profile/general.R | 4 ++++ R/pkg/inst/profile/shell.R | 4 ++++ docs/index.md | 3 ++- 5 files changed, 12 insertions(+), 3 deletions(-) diff --git a/R/WINDOWS.md b/R/WINDOWS.md index da668a69b8679..33a4c850cfdac 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -3,7 +3,7 @@ To build SparkR on Windows, the following steps are required 1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to -include Rtools and R in `PATH`. +include Rtools and R in `PATH`. Note that support for R prior to version 3.4 is deprecated as of Spark 3.0.0. 2. Install [JDK8](http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html) and set diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index cdaaa6104e6a9..736da46eaa8d3 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -15,7 +15,7 @@ URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html SystemRequirements: Java (== 8) Depends: - R (>= 3.0), + R (>= 3.1), methods Suggests: knitr, diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8c75c19ca7ac3..3efb460846fc2 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,6 +16,10 @@ # .First <- function() { + if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) { + warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0") + } + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") dirs <- strsplit(packageDir, ",")[[1]] .libPaths(c(dirs, .libPaths())) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 8a8111a8c5419..32eb3671b5941 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -16,6 +16,10 @@ # .First <- function() { + if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) { + warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0") + } + home <- Sys.getenv("SPARK_HOME") .libPaths(c(file.path(home, "R", "lib"), .libPaths())) Sys.setenv(NOAWT = 1) diff --git a/docs/index.md b/docs/index.md index ac38f1d4c53c2..bd287e3f8d83f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,7 +31,8 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} +Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. R prior to version 3.4 support is deprecated as of Spark 3.0.0. 
+For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). From 44d4ef60b8015fd8701a685cfb7c96c5ea57d3b1 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 15 Nov 2018 18:25:18 +0800 Subject: [PATCH 048/145] [SPARK-25974][SQL] Optimizes Generates bytecode for ordering based on the given order ## What changes were proposed in this pull request? Currently, when generates the code for ordering based on the given order, too many variables and assignment statements will be generated, which is not necessary. This PR will eliminate redundant variables. Optimizes Generates bytecode for ordering based on the given order. The generated code looks like: ``` spark.range(1).selectExpr( "id as key", "(id & 1023) as value1", "cast(id & 1023 as double) as value2", "cast(id & 1023 as int) as value3" ).select("value1", "value2", "value3").orderBy("value1", "value2").collect() ``` before PR(codegen size: 178) ``` Generated Ordering by input[0, bigint, false] ASC NULLS FIRST,input[1, double, false] ASC NULLS FIRST: /* 001 */ public SpecificOrdering generate(Object[] references) { /* 002 */ return new SpecificOrdering(references); /* 003 */ } /* 004 */ /* 005 */ class SpecificOrdering extends org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering { /* 006 */ /* 007 */ private Object[] references; /* 008 */ /* 009 */ /* 010 */ public SpecificOrdering(Object[] references) { /* 011 */ this.references = references; /* 012 */ /* 013 */ } /* 014 */ /* 015 */ public int compare(InternalRow a, InternalRow b) { /* 016 */ /* 017 */ InternalRow i = null; /* 018 */ /* 019 */ i = a; /* 020 */ boolean isNullA_0; /* 021 */ long primitiveA_0; /* 022 */ { /* 023 */ long value_0 = i.getLong(0); /* 024 */ isNullA_0 = false; /* 025 */ primitiveA_0 = value_0; /* 026 */ } /* 027 */ i = b; /* 028 */ boolean isNullB_0; /* 029 */ long primitiveB_0; /* 030 */ { /* 031 */ long value_0 = i.getLong(0); /* 032 */ isNullB_0 = false; /* 033 */ primitiveB_0 = value_0; /* 034 */ } /* 035 */ if (isNullA_0 && isNullB_0) { /* 036 */ // Nothing /* 037 */ } else if (isNullA_0) { /* 038 */ return -1; /* 039 */ } else if (isNullB_0) { /* 040 */ return 1; /* 041 */ } else { /* 042 */ int comp = (primitiveA_0 > primitiveB_0 ? 1 : primitiveA_0 < primitiveB_0 ? 
-1 : 0); /* 043 */ if (comp != 0) { /* 044 */ return comp; /* 045 */ } /* 046 */ } /* 047 */ /* 048 */ i = a; /* 049 */ boolean isNullA_1; /* 050 */ double primitiveA_1; /* 051 */ { /* 052 */ double value_1 = i.getDouble(1); /* 053 */ isNullA_1 = false; /* 054 */ primitiveA_1 = value_1; /* 055 */ } /* 056 */ i = b; /* 057 */ boolean isNullB_1; /* 058 */ double primitiveB_1; /* 059 */ { /* 060 */ double value_1 = i.getDouble(1); /* 061 */ isNullB_1 = false; /* 062 */ primitiveB_1 = value_1; /* 063 */ } /* 064 */ if (isNullA_1 && isNullB_1) { /* 065 */ // Nothing /* 066 */ } else if (isNullA_1) { /* 067 */ return -1; /* 068 */ } else if (isNullB_1) { /* 069 */ return 1; /* 070 */ } else { /* 071 */ int comp = org.apache.spark.util.Utils.nanSafeCompareDoubles(primitiveA_1, primitiveB_1); /* 072 */ if (comp != 0) { /* 073 */ return comp; /* 074 */ } /* 075 */ } /* 076 */ /* 077 */ /* 078 */ return 0; /* 079 */ } /* 080 */ /* 081 */ /* 082 */ } ``` After PR(codegen size: 89) ``` Generated Ordering by input[0, bigint, false] ASC NULLS FIRST,input[1, double, false] ASC NULLS FIRST: /* 001 */ public SpecificOrdering generate(Object[] references) { /* 002 */ return new SpecificOrdering(references); /* 003 */ } /* 004 */ /* 005 */ class SpecificOrdering extends org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering { /* 006 */ /* 007 */ private Object[] references; /* 008 */ /* 009 */ /* 010 */ public SpecificOrdering(Object[] references) { /* 011 */ this.references = references; /* 012 */ /* 013 */ } /* 014 */ /* 015 */ public int compare(InternalRow a, InternalRow b) { /* 016 */ /* 017 */ /* 018 */ long value_0 = a.getLong(0); /* 019 */ long value_2 = b.getLong(0); /* 020 */ if (false && false) { /* 021 */ // Nothing /* 022 */ } else if (false) { /* 023 */ return -1; /* 024 */ } else if (false) { /* 025 */ return 1; /* 026 */ } else { /* 027 */ int comp = (value_0 > value_2 ? 1 : value_0 < value_2 ? -1 : 0); /* 028 */ if (comp != 0) { /* 029 */ return comp; /* 030 */ } /* 031 */ } /* 032 */ /* 033 */ double value_1 = a.getDouble(1); /* 034 */ double value_3 = b.getDouble(1); /* 035 */ if (false && false) { /* 036 */ // Nothing /* 037 */ } else if (false) { /* 038 */ return -1; /* 039 */ } else if (false) { /* 040 */ return 1; /* 041 */ } else { /* 042 */ int comp = org.apache.spark.util.Utils.nanSafeCompareDoubles(value_1, value_3); /* 043 */ if (comp != 0) { /* 044 */ return comp; /* 045 */ } /* 046 */ } /* 047 */ /* 048 */ /* 049 */ return 0; /* 050 */ } /* 051 */ /* 052 */ /* 053 */ } ``` ## How was this patch tested? the existed test cases. Closes #22976 from heary-cao/GenArrayData. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../codegen/GenerateOrdering.scala | 113 ++++++++---------- 1 file changed, 51 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 9a51be6ed5aeb..c3b95b6c67fdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -68,62 +68,55 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR genComparisons(ctx, ordering) } + /** + * Creates the variables for ordering based on the given order. 
+ */ + private def createOrderKeys( + ctx: CodegenContext, + row: String, + ordering: Seq[SortOrder]): Seq[ExprCode] = { + ctx.INPUT_ROW = row + // to use INPUT_ROW we must make sure currentVars is null + ctx.currentVars = null + ordering.map(_.child.genCode(ctx)) + } + /** * Generates the code for ordering based on the given order. */ def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val oldInputRow = ctx.INPUT_ROW val oldCurrentVars = ctx.currentVars - val inputRow = "i" - ctx.INPUT_ROW = inputRow - // to use INPUT_ROW we must make sure currentVars is null - ctx.currentVars = null - - val comparisons = ordering.map { order => - val eval = order.child.genCode(ctx) - val asc = order.isAscending - val isNullA = ctx.freshName("isNullA") - val primitiveA = ctx.freshName("primitiveA") - val isNullB = ctx.freshName("isNullB") - val primitiveB = ctx.freshName("primitiveB") + val rowAKeys = createOrderKeys(ctx, "a", ordering) + val rowBKeys = createOrderKeys(ctx, "b", ordering) + val comparisons = rowAKeys.zip(rowBKeys).zipWithIndex.map { case ((l, r), i) => + val dt = ordering(i).child.dataType + val asc = ordering(i).isAscending + val nullOrdering = ordering(i).nullOrdering + val lRetValue = nullOrdering match { + case NullsFirst => "-1" + case NullsLast => "1" + } + val rRetValue = nullOrdering match { + case NullsFirst => "1" + case NullsLast => "-1" + } s""" - ${ctx.INPUT_ROW} = a; - boolean $isNullA; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; - { - ${eval.code} - $isNullA = ${eval.isNull}; - $primitiveA = ${eval.value}; - } - ${ctx.INPUT_ROW} = b; - boolean $isNullB; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; - { - ${eval.code} - $isNullB = ${eval.isNull}; - $primitiveB = ${eval.value}; - } - if ($isNullA && $isNullB) { - // Nothing - } else if ($isNullA) { - return ${ - order.nullOrdering match { - case NullsFirst => "-1" - case NullsLast => "1" - }}; - } else if ($isNullB) { - return ${ - order.nullOrdering match { - case NullsFirst => "1" - case NullsLast => "-1" - }}; - } else { - int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; - if (comp != 0) { - return ${if (asc) "comp" else "-comp"}; - } - } - """ + |${l.code} + |${r.code} + |if (${l.isNull} && ${r.isNull}) { + | // Nothing + |} else if (${l.isNull}) { + | return $lRetValue; + |} else if (${r.isNull}) { + | return $rRetValue; + |} else { + | int comp = ${ctx.genComp(dt, l.value, r.value)}; + | if (comp != 0) { + | return ${if (asc) "comp" else "-comp"}; + | } + |} + """.stripMargin } val code = ctx.splitExpressions( @@ -133,30 +126,26 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR returnType = "int", makeSplitFunction = { body => s""" - InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated. 
- $body - return 0; - """ + |$body + |return 0; + """.stripMargin }, foldFunctions = { funCalls => funCalls.zipWithIndex.map { case (funCall, i) => val comp = ctx.freshName("comp") s""" - int $comp = $funCall; - if ($comp != 0) { - return $comp; - } - """ + |int $comp = $funCall; + |if ($comp != 0) { + | return $comp; + |} + """.stripMargin }.mkString }) ctx.currentVars = oldCurrentVars ctx.INPUT_ROW = oldInputRow // make sure INPUT_ROW is declared even if splitExpressions // returns an inlined block - s""" - |InternalRow $inputRow = null; - |$code - """.stripMargin + code } protected def create(ordering: Seq[SortOrder]): BaseOrdering = { From b46f75a5af372422de0f8e07ff920fa6ccd33c7e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 15 Nov 2018 20:09:53 +0800 Subject: [PATCH 049/145] [SPARK-26057][SQL] Transform also analyzed plans when dedup references ## What changes were proposed in this pull request? In SPARK-24865 `AnalysisBarrier` was removed and in order to improve resolution speed, the `analyzed` flag was (re-)introduced in order to process only plans which are not yet analyzed. This should not be the case when performing attribute deduplication as in that case we need to transform also the plans which were already analyzed, otherwise we can miss to rewrite some attributes leading to invalid plans. ## How was this patch tested? added UT Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23035 from mgaido91/SPARK-26057. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c4e526081f4a2..ab2312fdcdeef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -871,7 +871,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan resolveOperatorsDown { case currentFragment => + plan transformDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2bb18f48e0ae2..0ee2627814ba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2554,4 +2554,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(swappedDf.filter($"key"($"map") > "a"), Row(2, Map(2 -> "b"))) } + + test("SPARK-26057: attribute deduplication on already analyzed plans") { + withTempView("a", "b", "v") { + val df1 = Seq(("1-1", 6)).toDF("id", "n") + df1.createOrReplaceTempView("a") + val df3 = Seq("1-1").toDF("id") + df3.createOrReplaceTempView("b") + spark.sql( + """ + |SELECT a.id, n as m + |FROM a + |WHERE EXISTS( + | SELECT 1 + | FROM b + | WHERE b.id = a.id) + """.stripMargin).createOrReplaceTempView("v") + val res = spark.sql( + """ + |SELECT a.id, n, m + | FROM a + | LEFT OUTER JOIN v ON v.id = a.id + """.stripMargin) + checkAnswer(res, Row("1-1", 6, 
6)) + } + } } From 9610efc252c94f93689d45e320df1c5815d97b25 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Nov 2018 20:25:27 +0800 Subject: [PATCH 050/145] [SPARK-26055][CORE] InterfaceStability annotations should be retained at runtime ## What changes were proposed in this pull request? It's good to have annotations available at runtime, so that tools like MiMa can detect them and deal with then specially. e.g. we don't want to track compatibility for unstable classes. This PR makes `InterfaceStability` annotations to be retained at runtime, to be consistent with `Experimental` and `DeveloperApi` ## How was this patch tested? N/A Closes #23029 from cloud-fan/annotation. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../org/apache/spark/annotation/DeveloperApi.java | 1 + .../org/apache/spark/annotation/Experimental.java | 1 + .../apache/spark/annotation/InterfaceStability.java | 11 ++++++++++- .../java/org/apache/spark/annotation/Private.java | 6 ++---- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java index 0ecef6db0e039..890f2faca28b0 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java @@ -29,6 +29,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java index ff8120291455f..96875920cd9c3 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java @@ -30,6 +30,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java index 323098f69c6e1..02bcec737e80e 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java @@ -17,7 +17,7 @@ package org.apache.spark.annotation; -import java.lang.annotation.Documented; +import java.lang.annotation.*; /** * Annotation to inform users of how much to rely on a particular package, @@ -31,6 +31,9 @@ public class InterfaceStability { * (e.g. from 1.0 to 2.0). */ @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface Stable {}; /** @@ -38,6 +41,9 @@ public class InterfaceStability { * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). 
*/ @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface Evolving {}; /** @@ -45,5 +51,8 @@ public class InterfaceStability { * Classes that are unannotated are considered Unstable. */ @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface Unstable {}; } diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Private.java b/common/tags/src/main/java/org/apache/spark/annotation/Private.java index 9082fcf0c84bc..a460d608ae16b 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/Private.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/Private.java @@ -17,10 +17,7 @@ package org.apache.spark.annotation; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; +import java.lang.annotation.*; /** * A class that is considered private to the internals of Spark -- there is a high-likelihood @@ -35,6 +32,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) From 91405b3b6eb4fa8047123d951859b6e2a1e46b6a Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 15 Nov 2018 09:22:31 -0600 Subject: [PATCH 051/145] [SPARK-22450][WIP][CORE][MLLIB][FOLLOWUP] Safely register MultivariateGaussian ## What changes were proposed in this pull request? register following classes in Kryo: "org.apache.spark.ml.stat.distribution.MultivariateGaussian", "org.apache.spark.mllib.stat.distribution.MultivariateGaussian" ## How was this patch tested? added tests Due to existing module dependency, I can not import spark-core in mllib-local's testsuits, so I do not add testsuite in `org.apache.spark.ml.stat.distribution.MultivariateGaussianSuite`. And I notice that class `ClusterStats` in `ClusteringEvaluator` is registered in a different way, should it be modified to keep in line with others in ML? srowen Closes #22974 from zhengruifeng/kryo_MultivariateGaussian. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../spark/serializer/KryoSerializer.scala | 10 ++++++++- .../distribution/MultivariateGaussian.scala | 4 ++-- .../distribution/MultivariateGaussian.scala | 4 ++-- .../spark/ml/attribute/AttributeSuite.scala | 19 +++++++++++++++- .../MultivariateGaussianSuite.scala | 22 ++++++++++++++++++- 5 files changed, 52 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 3795d5c3b38e3..66812a54846c6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -215,6 +215,12 @@ class KryoSerializer(conf: SparkConf) // We can't load those class directly in order to avoid unnecessary jar dependencies. // We load them safely, ignore it if the class not found. 
Seq( + "org.apache.spark.ml.attribute.Attribute", + "org.apache.spark.ml.attribute.AttributeGroup", + "org.apache.spark.ml.attribute.BinaryAttribute", + "org.apache.spark.ml.attribute.NominalAttribute", + "org.apache.spark.ml.attribute.NumericAttribute", + "org.apache.spark.ml.feature.Instance", "org.apache.spark.ml.feature.LabeledPoint", "org.apache.spark.ml.feature.OffsetInstance", @@ -224,6 +230,7 @@ class KryoSerializer(conf: SparkConf) "org.apache.spark.ml.linalg.SparseMatrix", "org.apache.spark.ml.linalg.SparseVector", "org.apache.spark.ml.linalg.Vector", + "org.apache.spark.ml.stat.distribution.MultivariateGaussian", "org.apache.spark.ml.tree.impl.TreePoint", "org.apache.spark.mllib.clustering.VectorWithNorm", "org.apache.spark.mllib.linalg.DenseMatrix", @@ -232,7 +239,8 @@ class KryoSerializer(conf: SparkConf) "org.apache.spark.mllib.linalg.SparseMatrix", "org.apache.spark.mllib.linalg.SparseVector", "org.apache.spark.mllib.linalg.Vector", - "org.apache.spark.mllib.regression.LabeledPoint" + "org.apache.spark.mllib.regression.LabeledPoint", + "org.apache.spark.mllib.stat.distribution.MultivariateGaussian" ).foreach { name => try { val clazz = Utils.classForName(name) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala index 3167e0c286d47..e7f7a8e07d7f2 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala @@ -48,14 +48,14 @@ class MultivariateGaussian @Since("2.0.0") ( this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov)) } - private val breezeMu = mean.asBreeze.toDenseVector + @transient private lazy val breezeMu = mean.asBreeze.toDenseVector /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ - private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants + @transient private lazy val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants /** * Returns density of this multivariate Gaussian at given point, x diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 4cf662e036346..9a746dcf35556 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -43,7 +43,7 @@ class MultivariateGaussian @Since("1.3.0") ( require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - private val breezeMu = mu.asBreeze.toDenseVector + @transient private lazy val breezeMu = mu.asBreeze.toDenseVector /** * private[mllib] constructor @@ -60,7 +60,7 @@ class MultivariateGaussian @Since("1.3.0") ( * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ - private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants + @transient private lazy val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** * Returns density of this multivariate Gaussian at given point, x diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 6355e0f179496..eb5f3ca45940d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.attribute -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.types._ class AttributeSuite extends SparkFunSuite { @@ -221,4 +222,20 @@ class AttributeSuite extends SparkFunSuite { val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata) assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } + + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val numericAttr = new NumericAttribute(Some("numeric"), Some(1), Some(1.0), Some(2.0)) + val nominalAttr = new NominalAttribute(Some("nominal"), Some(2), Some(false)) + val binaryAttr = new BinaryAttribute(Some("binary"), Some(3), Some(Array("i", "j"))) + + Seq(numericAttr, nominalAttr, binaryAttr).foreach { i => + val i2 = ser.deserialize[Attribute](ser.serialize(i)) + assert(i === i2) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 669d44223d713..5b4a2607f0b25 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.stat.distribution -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.serializer.KryoSerializer class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { @@ -80,4 +81,23 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) } + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val mu = Vectors.dense(0.0, 0.0) + val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) + val dist1 = new MultivariateGaussian(mu, sigma1) + + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) + val dist2 = new MultivariateGaussian(mu, sigma2) + + Seq(dist1, dist2).foreach { i => + val i2 = ser.deserialize[MultivariateGaussian](ser.serialize(i)) + assert(i.sigma === i2.sigma) + assert(i.mu === i2.mu) + } + } } From cae5879dbe5881a88c4925f6c5408f32d6f3860e Mon Sep 17 00:00:00 2001 From: Shahid Date: Thu, 15 Nov 2018 10:27:57 -0600 Subject: [PATCH 052/145] [SPARK-26044][WEBUI] Aggregated Metrics table sort based on executor ID ## What changes were proposed in this pull request? Aggregated Metrics table in the stage page is not sorted based on the executorID properly. 
This is because executorID is a string and the executor logs are rendered in the same column. In this PR, I created a new column for the executor logs. ## How was this patch tested? Before patch: ![screenshot from 2018-11-14 02-05-12](https://user-images.githubusercontent.com/23054875/48441529-caa77580-e7b1-11e8-90ea-b16f63438102.png) After patch: ![screenshot from 2018-11-14 02-05-29](https://user-images.githubusercontent.com/23054875/48441540-d2671a00-e7b1-11e8-9059-890bfe80c961.png) Closes #23024 from shahidki31/AggSort. Authored-by: Shahid Signed-off-by: Sean Owen --- .../apache/spark/ui/jobs/ExecutorTable.scala | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 0ff64f053f371..1be81e5ef9952 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -70,6 +70,7 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { Blacklisted + Logs {createExecutorTable(stage)} @@ -92,16 +93,7 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { executorSummary.toSeq.sortBy(_._1).map { case (k, v) => val executor = store.asOption(store.executorSummary(k)) - -
    {k}
    -
    - { - executor.map(_.executorLogs).getOrElse(Map.empty).map { - case (logName, logUrl) => - } - } -
- + {k} {executor.map { e => e.hostPort }.getOrElse("CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks + v.killedTasks} @@ -145,6 +137,11 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { false } } + {executor.map(_.executorLogs).getOrElse(Map.empty).map { + case (logName, logUrl) => + }} + + } } From 9a5fda60e532dc7203d21d5fbe385cd561906ccb Mon Sep 17 00:00:00 2001 From: Shanyu Zhao Date: Thu, 15 Nov 2018 10:30:16 -0600 Subject: [PATCH 053/145] [SPARK-26011][SPARK-SUBMIT] Yarn mode pyspark app without python main resource does not honor "spark.jars.packages" SparkSubmit determines a pyspark app by the suffix of the primary resource, but Livy uses "spark-internal" as the primary resource when calling spark-submit, so args.isPython is set to false in SparkSubmit.scala. In Yarn mode, the SparkSubmit module is responsible for resolving maven coordinates and adding them to "spark.submit.pyFiles" so that python's system path can be set correctly. The fix is to resolve maven coordinates not only when args.isPython is true, but also when the primary resource is spark-internal. Tested the patch with Livy submitting a pyspark app, with spark-submit, and with pyspark, with and without the packages config. Signed-off-by: Shanyu Zhao Closes #23009 from shanyu/shanyu-26011. Authored-by: Shanyu Zhao Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0fc8c9bd789e0..324f6f8894d34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -318,7 +318,7 @@ private[spark] class SparkSubmit extends Logging { if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) - if (args.isPython) { + if (args.isPython || isInternal(args.primaryResource)) { args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) } } From 3649fe599f1aa27fea0abd61c18d3ffa275d267b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 16 Nov 2018 07:58:09 +0800 Subject: [PATCH 054/145] [SPARK-26035][PYTHON] Break large streaming/tests.py files into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR continues to break a large file down into smaller files. See https://github.com/apache/spark/pull/23021. It aims to follow the layout of https://github.com/numpy/numpy/tree/master/numpy. Basically, this PR proposes to break down `pyspark/streaming/tests.py` as follows: ``` pyspark ├── __init__.py ... ├── streaming │   ├── __init__.py ... │   ├── tests │   │   ├── __init__.py │   │   ├── test_context.py │   │   ├── test_dstream.py │   │   ├── test_kinesis.py │   │   └── test_listener.py ... ├── testing ... │   ├── streamingutils.py ... ``` ## How was this patch tested? Existing tests should cover this: `cd python` and run `./run-tests-with-coverage`. Manually checked that they are actually being run. Each test (not officially) can be run via: ```bash SPARK_TESTING=1 ./bin/pyspark pyspark.tests.test_context ``` Note that if you're using Mac and Python 3, you might have to set `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`. Closes #23034 from HyukjinKwon/SPARK-26035.
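For reference, the same (unofficial) invocation pattern should also work for the new streaming test modules created by this patch; the snippet below is a minimal sketch, assuming a Spark checkout and the module names added under `python/pyspark/streaming/tests` in the diff that follows:

```bash
# Sketch only: run the newly split streaming test modules individually,
# following the same pattern shown above. Module names are taken from the
# files added by this patch; adjust if your layout differs.
SPARK_TESTING=1 ./bin/pyspark pyspark.streaming.tests.test_context
SPARK_TESTING=1 ./bin/pyspark pyspark.streaming.tests.test_dstream
SPARK_TESTING=1 ./bin/pyspark pyspark.streaming.tests.test_listener
# On Mac with Python 3, prefix the command with OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES.
```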
Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 7 +- python/pyspark/streaming/tests/__init__.py | 16 + .../pyspark/streaming/tests/test_context.py | 184 ++++++ .../{tests.py => tests/test_dstream.py} | 575 +----------------- .../pyspark/streaming/tests/test_kinesis.py | 89 +++ .../pyspark/streaming/tests/test_listener.py | 158 +++++ python/pyspark/testing/streamingutils.py | 190 ++++++ 7 files changed, 658 insertions(+), 561 deletions(-) create mode 100644 python/pyspark/streaming/tests/__init__.py create mode 100644 python/pyspark/streaming/tests/test_context.py rename python/pyspark/streaming/{tests.py => tests/test_dstream.py} (50%) create mode 100644 python/pyspark/streaming/tests/test_kinesis.py create mode 100644 python/pyspark/streaming/tests/test_listener.py create mode 100644 python/pyspark/testing/streamingutils.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d5fcc060616f2..58b48f43f6468 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -398,8 +398,13 @@ def __hash__(self): "python/pyspark/streaming" ], python_test_goals=[ + # doctests "pyspark.streaming.util", - "pyspark.streaming.tests", + # unittests + "pyspark.streaming.tests.test_context", + "pyspark.streaming.tests.test_dstream", + "pyspark.streaming.tests.test_kinesis", + "pyspark.streaming.tests.test_listener", ] ) diff --git a/python/pyspark/streaming/tests/__init__.py b/python/pyspark/streaming/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/streaming/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/streaming/tests/test_context.py b/python/pyspark/streaming/tests/test_context.py new file mode 100644 index 0000000000000..b44121462a920 --- /dev/null +++ b/python/pyspark/streaming/tests/test_context.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import struct +import tempfile +import time + +from pyspark.streaming import StreamingContext +from pyspark.testing.streamingutils import PySparkStreamingTestCase + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + setupCalled = False + + def _add_input_stream(self): + inputs = [range(1, x) for x in range(101)] + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.ssc.stop(False) + + def test_queue_stream(self): + input = [list(range(i + 1)) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], result) + + def test_binary_records_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream = self.ssc.binaryRecordsStream(d, 10).map( + lambda v: struct.unpack("10b", bytes(v))) + result = self._collect(dstream, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "wb") as f: + f.write(bytearray(range(10))) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) + + def test_union(self): + input = [list(range(i + 1)) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. 
+ dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() returns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + def test_await_termination_or_timeout(self): + self._add_input_stream() + self.ssc.start() + self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) + self.ssc.stop(False) + self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) + + +if __name__ == "__main__": + import unittest + from pyspark.streaming.tests.test_context import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests/test_dstream.py similarity index 50% rename from python/pyspark/streaming/tests.py rename to python/pyspark/streaming/tests/test_dstream.py index 8df00bc988430..d14e346b7a688 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests/test_dstream.py @@ -14,155 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - -import glob -import os -import sys -from itertools import chain -import time import operator -import tempfile -import random -import struct +import os import shutil +import tempfile +import time +import unittest from functools import reduce +from itertools import chain -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -if sys.version >= "3": - long = int - -from pyspark.context import SparkConf, SparkContext, RDD -from pyspark.storagelevel import StorageLevel -from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream -from pyspark.streaming.listener import StreamingListener - - -class PySparkStreamingTestCase(unittest.TestCase): - - timeout = 30 # seconds - duration = .5 - - @classmethod - def setUpClass(cls): - class_name = cls.__name__ - conf = SparkConf().set("spark.default.parallelism", 1) - cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir(tempfile.mkdtemp()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - # Clean up in the JVM just in case there has been some issues in Python API - try: - jSparkContextOption = SparkContext._jvm.SparkContext.get() - if jSparkContextOption.nonEmpty(): - jSparkContextOption.get().stop() - except: - pass - - def setUp(self): - self.ssc = StreamingContext(self.sc, self.duration) - - def tearDown(self): - if self.ssc is not None: - self.ssc.stop(False) - # Clean up in the JVM just in case there has been some issues in Python API - try: - jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() - if jStreamingContextOption.nonEmpty(): - jStreamingContextOption.get().stop(False) - except: - pass - - def wait_for(self, result, n): - start_time = time.time() - while len(result) < n and time.time() - start_time < self.timeout: - time.sleep(0.01) - if len(result) < n: - print("timeout after", self.timeout) - - def _take(self, dstream, n): - """ - Return the first `n` elements in the stream (will start and stop). - """ - results = [] - - def take(_, rdd): - if rdd and len(results) < n: - results.extend(rdd.take(n - len(results))) - - dstream.foreachRDD(take) - - self.ssc.start() - self.wait_for(results, n) - return results - - def _collect(self, dstream, n, block=True): - """ - Collect each RDDs into the returned list. - - :return: list, which will have the collected items. - """ - result = [] - - def get_output(_, rdd): - if rdd and len(result) < n: - r = rdd.collect() - if r: - result.append(r) - - dstream.foreachRDD(get_output) - - if not block: - return result - - self.ssc.start() - self.wait_for(result, n) - return result - - def _test_func(self, input, func, expected, sort=False, input2=None): - """ - @param input: dataset for the test. This should be list of lists. - @param func: wrapped function. This function should return PythonDStream object. - @param expected: expected output for this testcase. - """ - if not isinstance(input[0], RDD): - input = [self.sc.parallelize(d, 1) for d in input] - input_stream = self.ssc.queueStream(input) - if input2 and not isinstance(input2[0], RDD): - input2 = [self.sc.parallelize(d, 1) for d in input2] - input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None - - # Apply test function to stream. 
- if input2: - stream = func(input_stream, input_stream2) - else: - stream = func(input_stream) - - result = self._collect(stream, len(expected)) - if sort: - self._sort_result_based_on_key(result) - self._sort_result_based_on_key(expected) - self.assertEqual(expected, result) - - def _sort_result_based_on_key(self, outputs): - """Sort the list based on first value.""" - for output in outputs: - output.sort(key=lambda x: x[0]) +from pyspark import SparkConf, SparkContext, RDD +from pyspark.streaming import StreamingContext +from pyspark.testing.streamingutils import PySparkStreamingTestCase class BasicOperationTests(PySparkStreamingTestCase): @@ -526,135 +389,6 @@ def failed_func(i): self.fail("a failed func should throw an error") -class StreamingListenerTests(PySparkStreamingTestCase): - - duration = .5 - - class BatchInfoCollector(StreamingListener): - - def __init__(self): - super(StreamingListener, self).__init__() - self.batchInfosCompleted = [] - self.batchInfosStarted = [] - self.batchInfosSubmitted = [] - self.streamingStartedTime = [] - - def onStreamingStarted(self, streamingStarted): - self.streamingStartedTime.append(streamingStarted.time) - - def onBatchSubmitted(self, batchSubmitted): - self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) - - def onBatchStarted(self, batchStarted): - self.batchInfosStarted.append(batchStarted.batchInfo()) - - def onBatchCompleted(self, batchCompleted): - self.batchInfosCompleted.append(batchCompleted.batchInfo()) - - def test_batch_info_reports(self): - batch_collector = self.BatchInfoCollector() - self.ssc.addStreamingListener(batch_collector) - input = [[1], [2], [3], [4]] - - def func(dstream): - return dstream.map(int) - expected = [[1], [2], [3], [4]] - self._test_func(input, func, expected) - - batchInfosSubmitted = batch_collector.batchInfosSubmitted - batchInfosStarted = batch_collector.batchInfosStarted - batchInfosCompleted = batch_collector.batchInfosCompleted - streamingStartedTime = batch_collector.streamingStartedTime - - self.wait_for(batchInfosCompleted, 4) - - self.assertEqual(len(streamingStartedTime), 1) - - self.assertGreaterEqual(len(batchInfosSubmitted), 4) - for info in batchInfosSubmitted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), -1) - self.assertGreaterEqual(outputInfo.endTime(), -1) - self.assertIsNone(outputInfo.failureReason()) - - self.assertEqual(info.schedulingDelay(), -1) - self.assertEqual(info.processingDelay(), -1) - self.assertEqual(info.totalDelay(), -1) - self.assertEqual(info.numRecords(), 0) - - self.assertGreaterEqual(len(batchInfosStarted), 4) - for info in batchInfosStarted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - 
for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), -1) - self.assertGreaterEqual(outputInfo.endTime(), -1) - self.assertIsNone(outputInfo.failureReason()) - - self.assertGreaterEqual(info.schedulingDelay(), 0) - self.assertEqual(info.processingDelay(), -1) - self.assertEqual(info.totalDelay(), -1) - self.assertEqual(info.numRecords(), 0) - - self.assertGreaterEqual(len(batchInfosCompleted), 4) - for info in batchInfosCompleted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), 0) - self.assertGreaterEqual(outputInfo.endTime(), 0) - self.assertIsNone(outputInfo.failureReason()) - - self.assertGreaterEqual(info.schedulingDelay(), 0) - self.assertGreaterEqual(info.processingDelay(), 0) - self.assertGreaterEqual(info.totalDelay(), 0) - self.assertEqual(info.numRecords(), 0) - - class WindowFunctionTests(PySparkStreamingTestCase): timeout = 15 @@ -732,156 +466,6 @@ def func(dstream): self._test_func(input, func, expected) -class StreamingContextTests(PySparkStreamingTestCase): - - duration = 0.1 - setupCalled = False - - def _add_input_stream(self): - inputs = [range(1, x) for x in range(101)] - stream = self.ssc.queueStream(inputs) - self._collect(stream, 1, block=False) - - def test_stop_only_streaming_context(self): - self._add_input_stream() - self.ssc.start() - self.ssc.stop(False) - self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) - - def test_stop_multiple_times(self): - self._add_input_stream() - self.ssc.start() - self.ssc.stop(False) - self.ssc.stop(False) - - def test_queue_stream(self): - input = [list(range(i + 1)) for i in range(3)] - dstream = self.ssc.queueStream(input) - result = self._collect(dstream, 3) - self.assertEqual(input, result) - - def test_text_file_stream(self): - d = tempfile.mkdtemp() - self.ssc = StreamingContext(self.sc, self.duration) - dstream2 = self.ssc.textFileStream(d).map(int) - result = self._collect(dstream2, 2, block=False) - self.ssc.start() - for name in ('a', 'b'): - time.sleep(1) - with open(os.path.join(d, name), "w") as f: - f.writelines(["%d\n" % i 
for i in range(10)]) - self.wait_for(result, 2) - self.assertEqual([list(range(10)), list(range(10))], result) - - def test_binary_records_stream(self): - d = tempfile.mkdtemp() - self.ssc = StreamingContext(self.sc, self.duration) - dstream = self.ssc.binaryRecordsStream(d, 10).map( - lambda v: struct.unpack("10b", bytes(v))) - result = self._collect(dstream, 2, block=False) - self.ssc.start() - for name in ('a', 'b'): - time.sleep(1) - with open(os.path.join(d, name), "wb") as f: - f.write(bytearray(range(10))) - self.wait_for(result, 2) - self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) - - def test_union(self): - input = [list(range(i + 1)) for i in range(3)] - dstream = self.ssc.queueStream(input) - dstream2 = self.ssc.queueStream(input) - dstream3 = self.ssc.union(dstream, dstream2) - result = self._collect(dstream3, 3) - expected = [i * 2 for i in input] - self.assertEqual(expected, result) - - def test_transform(self): - dstream1 = self.ssc.queueStream([[1]]) - dstream2 = self.ssc.queueStream([[2]]) - dstream3 = self.ssc.queueStream([[3]]) - - def func(rdds): - rdd1, rdd2, rdd3 = rdds - return rdd2.union(rdd3).union(rdd1) - - dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) - - self.assertEqual([2, 3, 1], self._take(dstream, 3)) - - def test_transform_pairrdd(self): - # This regression test case is for SPARK-17756. - dstream = self.ssc.queueStream( - [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) - self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) - - def test_get_active(self): - self.assertEqual(StreamingContext.getActive(), None) - - # Verify that getActive() returns the active context - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - - # Verify that getActive() returns None - self.ssc.stop(False) - self.assertEqual(StreamingContext.getActive(), None) - - # Verify that if the Java context is stopped, then getActive() returns None - self.ssc = StreamingContext(self.sc, self.duration) - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - self.ssc._jssc.stop(False) - self.assertEqual(StreamingContext.getActive(), None) - - def test_get_active_or_create(self): - # Test StreamingContext.getActiveOrCreate() without checkpoint data - # See CheckpointTests for tests with checkpoint data - self.ssc = None - self.assertEqual(StreamingContext.getActive(), None) - - def setupFunc(): - ssc = StreamingContext(self.sc, self.duration) - ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.setupCalled = True - return ssc - - # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - # Verify that getActiveOrCreate() returns active context and does not call the setupFunc - self.ssc.start() - self.setupCalled = False - self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) - self.assertFalse(self.setupCalled) - - # Verify that getActiveOrCreate() calls setupFunc after active context is stopped - self.ssc.stop(False) - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - # Verify that if the Java context is stopped, then getActive() returns None - self.ssc = 
StreamingContext(self.sc, self.duration) - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - self.ssc._jssc.stop(False) - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - def test_await_termination_or_timeout(self): - self._add_input_stream() - self.ssc.start() - self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) - self.ssc.stop(False) - self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) - - class CheckpointTests(unittest.TestCase): setupCalled = False @@ -1046,140 +630,11 @@ def check_output(n): self.ssc.stop(True, True) -class KinesisStreamTests(PySparkStreamingTestCase): - - def test_kinesis_stream_api(self): - # Don't start the StreamingContext because we cannot test it in Jenkins - kinesisStream1 = KinesisUtils.createStream( - self.ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) - kinesisStream2 = KinesisUtils.createStream( - self.ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") - - def test_kinesis_stream(self): - if not are_kinesis_tests_enabled: - sys.stderr.write( - "Skipped test_kinesis_stream (enable by setting environment variable %s=1" - % kinesis_test_environ_var) - return - - import random - kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) - kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2) - try: - kinesisTestUtils.createStream() - aWSCredentials = kinesisTestUtils.getAWSCredentials() - stream = KinesisUtils.createStream( - self.ssc, kinesisAppName, kinesisTestUtils.streamName(), - kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), - InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) - - outputBuffer = [] - - def get_output(_, rdd): - for e in rdd.collect(): - outputBuffer.append(e) - - stream.foreachRDD(get_output) - self.ssc.start() - - testData = [i for i in range(1, 11)] - expectedOutput = set([str(i) for i in testData]) - start_time = time.time() - while time.time() - start_time < 120: - kinesisTestUtils.pushData(testData) - if expectedOutput == set(outputBuffer): - break - time.sleep(10) - self.assertEqual(expectedOutput, set(outputBuffer)) - except: - import traceback - traceback.print_exc() - raise - finally: - self.ssc.stop(False) - kinesisTestUtils.deleteStream() - kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) - - -# Search jar in the project dir using the jar name_prefix for both sbt build and maven build because -# the artifact jars are in different directories. 
-def search_jar(dir, name_prefix): - # We should ignore the following jars - ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") - jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) + # sbt build - glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar"))) # maven build - return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)] - - -def _kinesis_asl_assembly_dir(): - SPARK_HOME = os.environ["SPARK_HOME"] - return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - - -def search_kinesis_asl_assembly_jar(): - jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") - if not jars: - return None - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py -kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" -are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' - if __name__ == "__main__": - from pyspark.streaming.tests import * - kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - - if kinesis_asl_assembly_jar is None: - kinesis_jar_present = False - jars_args = "" - else: - kinesis_jar_present = True - jars_args = "--jars %s" % kinesis_asl_assembly_jar - - existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") - os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) - testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - StreamingListenerTests] - - if kinesis_jar_present is True: - testcases.append(KinesisStreamTests) - elif are_kinesis_tests_enabled is False: - sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " - "not compiled into a JAR. To run these tests, " - "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " - "streaming-kinesis-asl-assembly/assembly' or " - "'build/mvn -Pkinesis-asl package' before running this test.") - else: - raise Exception( - ("Failed to find Spark Streaming Kinesis assembly jar in %s. 
" - % _kinesis_asl_assembly_dir()) + - "You need to build Spark with 'build/sbt -Pkinesis-asl " - "assembly/package streaming-kinesis-asl-assembly/assembly'" - "or 'build/mvn -Pkinesis-asl package' before running this test.") - - sys.stderr.write("Running tests: %s \n" % (str(testcases))) - failed = False - for testcase in testcases: - sys.stderr.write("[Running %s]\n" % (testcase)) - tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) - if not result.wasSuccessful(): - failed = True - else: - result = unittest.TextTestRunner(verbosity=2).run(tests) - if not result.wasSuccessful(): - failed = True - sys.exit(failed) + from pyspark.streaming.tests.test_dstream import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests/test_kinesis.py b/python/pyspark/streaming/tests/test_kinesis.py new file mode 100644 index 0000000000000..d8a0b47f04097 --- /dev/null +++ b/python/pyspark/streaming/tests/test_kinesis.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import time +import unittest + +from pyspark import StorageLevel +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream +from pyspark.testing.streamingutils import should_test_kinesis, kinesis_requirement_message, \ + PySparkStreamingTestCase + + +@unittest.skipIf(not should_test_kinesis, kinesis_requirement_message) +class KinesisStreamTests(PySparkStreamingTestCase): + + def test_kinesis_stream_api(self): + # Don't start the StreamingContext because we cannot test it in Jenkins + KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) + KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + + def test_kinesis_stream(self): + import random + kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) + kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2) + try: + kinesisTestUtils.createStream() + aWSCredentials = kinesisTestUtils.getAWSCredentials() + stream = KinesisUtils.createStream( + self.ssc, kinesisAppName, kinesisTestUtils.streamName(), + kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), + InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) + + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + stream.foreachRDD(get_output) + self.ssc.start() + + testData = [i for i in range(1, 11)] + expectedOutput = set([str(i) for i in testData]) + start_time = time.time() + while time.time() - start_time < 120: + kinesisTestUtils.pushData(testData) + if expectedOutput == set(outputBuffer): + break + time.sleep(10) + self.assertEqual(expectedOutput, set(outputBuffer)) + except: + import traceback + traceback.print_exc() + raise + finally: + self.ssc.stop(False) + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + + +if __name__ == "__main__": + from pyspark.streaming.tests.test_kinesis import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests/test_listener.py b/python/pyspark/streaming/tests/test_listener.py new file mode 100644 index 0000000000000..7c874b6b32500 --- /dev/null +++ b/python/pyspark/streaming/tests/test_listener.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from pyspark.streaming import StreamingListener +from pyspark.testing.streamingutils import PySparkStreamingTestCase + + +class StreamingListenerTests(PySparkStreamingTestCase): + + duration = .5 + + class BatchInfoCollector(StreamingListener): + + def __init__(self): + super(StreamingListener, self).__init__() + self.batchInfosCompleted = [] + self.batchInfosStarted = [] + self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) + + def onBatchSubmitted(self, batchSubmitted): + self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) + + def onBatchStarted(self, batchStarted): + self.batchInfosStarted.append(batchStarted.batchInfo()) + + def onBatchCompleted(self, batchCompleted): + self.batchInfosCompleted.append(batchCompleted.batchInfo()) + + def test_batch_info_reports(self): + batch_collector = self.BatchInfoCollector() + self.ssc.addStreamingListener(batch_collector) + input = [[1], [2], [3], [4]] + + def func(dstream): + return dstream.map(int) + expected = [[1], [2], [3], [4]] + self._test_func(input, func, expected) + + batchInfosSubmitted = batch_collector.batchInfosSubmitted + batchInfosStarted = batch_collector.batchInfosStarted + batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime + + self.wait_for(batchInfosCompleted, 4) + + self.assertEqual(len(streamingStartedTime), 1) + + self.assertGreaterEqual(len(batchInfosSubmitted), 4) + for info in batchInfosSubmitted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertEqual(info.schedulingDelay(), -1) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosStarted), 4) + for info in batchInfosStarted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + 
self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosCompleted), 4) + for info in batchInfosCompleted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), 0) + self.assertGreaterEqual(outputInfo.endTime(), 0) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertGreaterEqual(info.processingDelay(), 0) + self.assertGreaterEqual(info.totalDelay(), 0) + self.assertEqual(info.numRecords(), 0) + + +if __name__ == "__main__": + import unittest + from pyspark.streaming.tests.test_listener import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/testing/streamingutils.py b/python/pyspark/testing/streamingutils.py new file mode 100644 index 0000000000000..85a2fa14b936c --- /dev/null +++ b/python/pyspark/testing/streamingutils.py @@ -0,0 +1,190 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import glob +import os +import tempfile +import time +import unittest + +from pyspark import SparkConf, SparkContext, RDD +from pyspark.streaming import StreamingContext + + +def search_kinesis_asl_assembly_jar(): + kinesis_asl_assembly_dir = os.path.join( + os.environ["SPARK_HOME"], "external/kinesis-asl-assembly") + + # We should ignore the following jars + ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") + + # Search jar in the project dir using the jar name_prefix for both sbt build and maven + # build because the artifact jars are in different directories. + name_prefix = "spark-streaming-kinesis-asl-assembly" + sbt_build = glob.glob(os.path.join( + kinesis_asl_assembly_dir, "target/scala-*/%s-*.jar" % name_prefix)) + maven_build = glob.glob(os.path.join( + kinesis_asl_assembly_dir, "target/%s_*.jar" % name_prefix)) + jar_paths = sbt_build + maven_build + jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)] + + if not jars: + return None + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +should_skip_kinesis_tests = not os.environ.get(kinesis_test_environ_var) == '1' + +if should_skip_kinesis_tests: + kinesis_requirement_message = ( + "Skipping all Kinesis Python tests as environmental variable 'ENABLE_KINESIS_TESTS' " + "was not set.") +else: + kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() + if kinesis_asl_assembly_jar is None: + kinesis_requirement_message = ( + "Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled into a JAR. 
To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % kinesis_asl_assembly_jar + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) + kinesis_requirement_message = None + +should_test_kinesis = kinesis_requirement_message is None + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 30 # seconds + duration = .5 + + @classmethod + def setUpClass(cls): + class_name = cls.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir(tempfile.mkdtemp()) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + try: + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + except: + pass + + def setUp(self): + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + try: + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) + except: + pass + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print("timeout after", self.timeout) + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. 
+ if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) From dad2d826ae9138f06751e5d092531a9e06028c21 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 16 Nov 2018 12:46:57 +0800 Subject: [PATCH 055/145] [SPARK-23207][SQL][FOLLOW-UP] Use `SQLConf.get.enableRadixSort` instead of `SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)`. ## What changes were proposed in this pull request? This is a follow-up of #20393. We should read the conf `"spark.sql.sort.enableRadixSort"` from `SQLConf` instead of `SparkConf`, i.e., use `SQLConf.get.enableRadixSort` instead of `SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)`, otherwise the config is never read. ## How was this patch tested? Existing tests. Closes #23046 from ueshin/issues/SPARK-23207/conf. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 3b6eebd41e886..d6742ab3e0f31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -280,7 +280,7 @@ object ShuffleExchangeExec { } // The comparator for comparing row hashcode, which should always be Integer. val prefixComparator = PrefixComparators.LONG - val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + val canUseRadixSort = SQLConf.get.enableRadixSort // The prefix computer generates row hashcode as the prefix, so we may decrease the // probability that the prefixes are equal when input rows choose column values from a // limited range. From 4ac8f9becda42e83131df87c68bcd1b0dfb50ac8 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Fri, 16 Nov 2018 13:10:44 +0800 Subject: [PATCH 056/145] [SPARK-26073][SQL][FOLLOW-UP] remove invalid comment as we don't use it anymore ## What changes were proposed in this pull request? remove invalid comment as we don't use it anymore More details: https://github.com/apache/spark/pull/22976#discussion_r233764857 ## How was this patch tested? N/A Closes #23044 from heary-cao/followUpOrdering. 
Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/codegen/GenerateOrdering.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index c3b95b6c67fdd..283fd2a6e9383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -143,8 +143,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR }) ctx.currentVars = oldCurrentVars ctx.INPUT_ROW = oldInputRow - // make sure INPUT_ROW is declared even if splitExpressions - // returns an inlined block code } From 2aef79a65a145b76a88f1d4d9367091fd238b949 Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Fri, 16 Nov 2018 08:53:29 -0600 Subject: [PATCH 057/145] [SPARK-25023] More detailed security guidance for K8S ## What changes were proposed in this pull request? Highlights specific security issues to be aware of with Spark on K8S and recommends K8S mechanisms that should be used to secure clusters. ## How was this patch tested? N/A - Documentation only CC felixcheung tgravescs skonto Closes #23013 from rvesse/SPARK-25023. Authored-by: Rob Vesse Signed-off-by: Sean Owen --- docs/running-on-kubernetes.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 905226877720a..a7b6fd12a3e5f 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -15,7 +15,19 @@ container images and entrypoints.** # Security Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. -Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark. +Please see [Spark Security](security.html) and the specific advice below before running Spark. + +## User Identity + +Images built from the project provided Dockerfiles do not contain any [`USER`](https://docs.docker.com/engine/reference/builder/#user) directives. This means that the resulting images will be running the Spark processes as `root` inside the container. On unsecured clusters this may provide an attack vector for privilege escalation and container breakout. Therefore security conscious deployments should consider providing custom images with `USER` directives specifying an unprivileged UID and GID. + +Alternatively the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/#users-and-groups) if they wish to limit the users that pods may run as. + +## Volume Mounts + +As described later in this document under [Using Kubernetes Volumes](#using-kubernetes-volumes) Spark on K8S provides configuration options that allow for mounting certain volume types into the driver and executor pods. 
In particular it allows for [`hostPath`](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath) volumes which as described in the Kubernetes documentation have known security vulnerabilities. + +Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/) to limit the ability to mount `hostPath` volumes appropriately for their environments. # Prerequisites @@ -214,6 +226,8 @@ Starting with Spark 2.4.0, users can mount the following types of Kubernetes [vo * [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir): an initially empty volume created when a pod is assigned to a node. * [persistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/volumes/#persistentvolumeclaim): used to mount a `PersistentVolume` into a pod. +**NB:** Please see the [Security](#security) section of this document for security issues related to volume mounts. + To mount a volume of any of the types above into the driver pod, use the following configuration property: ``` From 696b75a81013ad61d25e0552df2b019c7531f983 Mon Sep 17 00:00:00 2001 From: Matt Molek Date: Fri, 16 Nov 2018 10:00:21 -0600 Subject: [PATCH 058/145] [SPARK-25934][MESOS] Don't propagate SPARK_CONF_DIR from spark submit ## What changes were proposed in this pull request? Don't propagate SPARK_CONF_DIR to the driver in mesos cluster mode. ## How was this patch tested? I built the 2.3.2 tag with this patch added and deployed a test job to a mesos cluster to confirm that the incorrect SPARK_CONF_DIR was no longer passed from the submit command. Closes #22937 from mpmolek/fix-conf-dir. Authored-by: Matt Molek Signed-off-by: Sean Owen --- .../spark/deploy/rest/RestSubmissionClient.scala | 8 +++++--- .../deploy/rest/StandaloneRestSubmitSuite.scala | 12 ++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 31a8e3e60c067..afa413fe165df 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -408,6 +408,10 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { } private[spark] object RestSubmissionClient { + + // SPARK_HOME and SPARK_CONF_DIR are filtered out because they are usually wrong + // on the remote machine (SPARK-12345) (SPARK-25934) + private val BLACKLISTED_SPARK_ENV_VARS = Set("SPARK_ENV_LOADED", "SPARK_HOME", "SPARK_CONF_DIR") private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -417,9 +421,7 @@ private[spark] object RestSubmissionClient { */ private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { env.filterKeys { k => - // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345) - (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") || - k.startsWith("MESOS_") + (k.startsWith("SPARK_") && !BLACKLISTED_SPARK_ENV_VARS.contains(k)) || k.startsWith("MESOS_") } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 4839c842cc785..89b8bb4ff7d03 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -396,6 +396,18 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { assert(filteredVariables == Map("SPARK_VAR" -> "1")) } + test("client does not send 'SPARK_HOME' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_HOME" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + + test("client does not send 'SPARK_CONF_DIR' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_CONF_DIR" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + test("client includes mesos env vars") { val environmentVariables = Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1", "OTHER_VAR" -> "1") val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) From a2fc48c28c06192d1f650582d128d60c7188ec62 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sat, 17 Nov 2018 00:12:17 +0800 Subject: [PATCH 059/145] [SPARK-26034][PYTHON][TESTS] Break large mllib/tests.py file into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR breaks down the large mllib/tests.py file that contains all Python MLlib unit tests into several smaller test files to be easier to read and maintain. The tests are broken down as follows: ``` pyspark ├── __init__.py ... ├── mllib │ ├── __init__.py ... │ ├── tests │ │ ├── __init__.py │ │ ├── test_algorithms.py │ │ ├── test_feature.py │ │ ├── test_linalg.py │ │ ├── test_stat.py │ │ ├── test_streaming_algorithms.py │ │ └── test_util.py ... ├── testing ... │ ├── mllibutils.py ... ``` ## How was this patch tested? Ran tests manually by module to ensure test count was the same, and ran `python/run-tests --modules=pyspark-mllib` to verify all passing with Python 2.7 and Python 3.6. Also installed scipy to include optional tests in test_linalg. Closes #23056 from BryanCutler/python-test-breakup-mllib-SPARK-26034. 
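For reference, the split modules presumably share their SparkContext setup through the new `python/pyspark/testing/mllibutils.py` helper listed above; a minimal, hypothetical sketch of that shared base, assuming it simply relocates the `MLlibTestCase` setup/teardown from the old monolithic `mllib/tests.py`, could look like:

```python
# Hypothetical sketch of pyspark/testing/mllibutils.py: the MLlibTestCase base
# class moved out of the old python/pyspark/mllib/tests.py so that each of the
# new test_*.py files can import it instead of redefining it.
import unittest

from pyspark import SparkContext
from pyspark.sql import SparkSession


class MLlibTestCase(unittest.TestCase):
    def setUp(self):
        # One local SparkContext/SparkSession per test, as in the old file.
        self.sc = SparkContext('local[4]', "MLlib tests")
        self.spark = SparkSession(self.sc)

    def tearDown(self):
        self.spark.stop()
```

Each split file (e.g. `test_linalg.py`) would then presumably start with `from pyspark.testing.mllibutils import MLlibTestCase` and keep only its own test classes.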
Authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 9 +- python/pyspark/mllib/tests.py | 1787 ----------------- python/pyspark/mllib/tests/__init__.py | 16 + python/pyspark/mllib/tests/test_algorithms.py | 313 +++ python/pyspark/mllib/tests/test_feature.py | 201 ++ python/pyspark/mllib/tests/test_linalg.py | 642 ++++++ python/pyspark/mllib/tests/test_stat.py | 197 ++ .../mllib/tests/test_streaming_algorithms.py | 523 +++++ python/pyspark/mllib/tests/test_util.py | 115 ++ python/pyspark/testing/mllibutils.py | 44 + 10 files changed, 2059 insertions(+), 1788 deletions(-) delete mode 100644 python/pyspark/mllib/tests.py create mode 100644 python/pyspark/mllib/tests/__init__.py create mode 100644 python/pyspark/mllib/tests/test_algorithms.py create mode 100644 python/pyspark/mllib/tests/test_feature.py create mode 100644 python/pyspark/mllib/tests/test_linalg.py create mode 100644 python/pyspark/mllib/tests/test_stat.py create mode 100644 python/pyspark/mllib/tests/test_streaming_algorithms.py create mode 100644 python/pyspark/mllib/tests/test_util.py create mode 100644 python/pyspark/testing/mllibutils.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 58b48f43f6468..547635a412913 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -416,6 +416,7 @@ def __hash__(self): "python/pyspark/mllib" ], python_test_goals=[ + # doctests "pyspark.mllib.classification", "pyspark.mllib.clustering", "pyspark.mllib.evaluation", @@ -430,7 +431,13 @@ def __hash__(self): "pyspark.mllib.stat.KernelDensity", "pyspark.mllib.tree", "pyspark.mllib.util", - "pyspark.mllib.tests", + # unittests + "pyspark.mllib.tests.test_algorithms", + "pyspark.mllib.tests.test_feature", + "pyspark.mllib.tests.test_linalg", + "pyspark.mllib.tests.test_stat", + "pyspark.mllib.tests.test_streaming_algorithms", + "pyspark.mllib.tests.test_util", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py deleted file mode 100644 index 4c2ce137e331c..0000000000000 --- a/python/pyspark/mllib/tests.py +++ /dev/null @@ -1,1787 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Fuller unit tests for Python MLlib. 
-""" - -import os -import sys -import tempfile -import array as pyarray -from math import sqrt -from time import time, sleep -from shutil import rmtree - -from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) -from numpy import sum as array_sum - -from py4j.protocol import Py4JJavaError -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version > '3': - basestring = str - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark import SparkContext -import pyspark.ml.linalg as newlinalg -from pyspark.mllib.common import _to_java_object_rdd -from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT -from pyspark.mllib.linalg.distributed import RowMatrix -from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD -from pyspark.mllib.fpm import FPGrowth -from pyspark.mllib.recommendation import Rating -from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD -from pyspark.mllib.random import RandomRDDs -from pyspark.mllib.stat import Statistics -from pyspark.mllib.feature import HashingTF -from pyspark.mllib.feature import Word2Vec -from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler, ElementwiseProduct -from pyspark.mllib.util import LinearDataGenerator -from pyspark.mllib.util import MLUtils -from pyspark.serializers import PickleSerializer -from pyspark.streaming import StreamingContext -from pyspark.sql import SparkSession -from pyspark.sql.utils import IllegalArgumentException -from pyspark.streaming import StreamingContext - -_have_scipy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class MLLibStreamingTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.ssc = StreamingContext(self.sc, 1.0) - - def tearDown(self): - self.ssc.stop(False) - self.sc.stop() - - @staticmethod - def _eventually(condition, timeout=30.0, catch_assertions=False): - """ - Wait a given amount of time for a condition to pass, else fail with an error. - This is a helper utility for streaming ML tests. - :param condition: Function that checks for termination conditions. - condition() can return: - - True: Conditions met. Return without error. - - other value: Conditions not met yet. Continue. Upon timeout, - include last such value in error message. - Note that this method may be called at any time during - streaming execution (e.g., even before any results - have been created). - :param timeout: Number of seconds to wait. Default 30 seconds. - :param catch_assertions: If False (default), do not catch AssertionErrors. - If True, catch AssertionErrors; continue, but save - error to throw upon timeout. 
- """ - start_time = time() - lastValue = None - while time() - start_time < timeout: - if catch_assertions: - try: - lastValue = condition() - except AssertionError as e: - lastValue = e - else: - lastValue = condition() - if lastValue is True: - return - sleep(0.01) - if isinstance(lastValue, AssertionError): - raise lastValue - else: - raise AssertionError( - "Test failed due to timeout after %g sec, with last condition returning: %s" - % (timeout, lastValue)) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def 
test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) - for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. - smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_parse_vector(self): - a = DenseVector([]) - self.assertEqual(str(a), '[]') - self.assertEqual(Vectors.parse(str(a)), a) - a = DenseVector([3, 4, 6, 7]) - self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(4, [], []) - self.assertEqual(str(a), '(4,[],[])') - self.assertEqual(SparseVector.parse(str(a)), a) - a = SparseVector(4, [0, 2], [3, 4]) - self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(10, [0, 1], [4, 5]) - self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - 
self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - def test_ml_mllib_vector_conversion(self): - # to ml - # dense - mllibDV = Vectors.dense([1, 2, 3]) - mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) - mlDV2 = mllibDV.asML() - self.assertEqual(mlDV2, mlDV1) - # sparse - mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV2 = mllibSV.asML() - self.assertEqual(mlSV2, mlSV1) - # from ml - # dense - mllibDV1 = Vectors.dense([1, 2, 3]) - mlDV = newlinalg.Vectors.dense([1, 2, 3]) - mllibDV2 = Vectors.fromML(mlDV) - self.assertEqual(mllibDV1, mllibDV2) - # sparse - mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mllibSV2 = Vectors.fromML(mlSV) - self.assertEqual(mllibSV1, mllibSV2) - - def test_ml_mllib_matrix_conversion(self): - # to ml - # dense - mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM2 = mllibDM.asML() - self.assertEqual(mlDM2, mlDM1) - # transposed - mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt2 = mllibDMt.asML() - self.assertEqual(mlDMt2, mlDMt1) - # sparse - mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM2 = mllibSM.asML() - self.assertEqual(mlSM2, mlSM1) - # transposed - mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt2 = mllibSMt.asML() - self.assertEqual(mlSMt2, mlSMt1) - # from ml - # dense - mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) - mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) - mllibDM2 = Matrices.fromML(mlDM) - self.assertEqual(mllibDM1, mllibDM2) - # transposed - mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True) - mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) - mllibDMt2 = Matrices.fromML(mlDMt) - self.assertEqual(mllibDMt1, mllibDMt2) - # sparse - mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mllibSM2 = Matrices.fromML(mlSM) - self.assertEqual(mllibSM1, mllibSM2) - # transposed - mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mllibSMt2 = Matrices.fromML(mlSMt) - self.assertEqual(mllibSMt1, mllibSMt2) - - -class ListTests(MLlibTestCase): - - """ - Test MLlib algorithms on plain lists, to make sure they're passed through - as NumPy arrays. 
- """ - - def test_bisecting_kmeans(self): - from pyspark.mllib.clustering import BisectingKMeans - data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) - bskm = BisectingKMeans() - model = bskm.train(self.sc.parallelize(data, 2), k=4) - p = array([0.0, 0.0]) - rdd_p = self.sc.parallelize([p]) - self.assertEqual(model.predict(p), model.predict(rdd_p).first()) - self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) - self.assertEqual(model.k, len(model.clusterCenters)) - - def test_kmeans(self): - from pyspark.mllib.clustering import KMeans - data = [ - [0, 1.1], - [0, 1.2], - [1.1, 0], - [1.2, 0], - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", - initializationSteps=7, epsilon=1e-4) - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_kmeans_deterministic(self): - from pyspark.mllib.clustering import KMeans - X = range(0, 100, 10) - Y = range(0, 100, 10) - data = [[x, y] for x, y in zip(X, Y)] - clusters1 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - clusters2 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - centers1 = clusters1.centers - centers2 = clusters2.centers - for c1, c2 in zip(centers1, centers2): - # TODO: Allow small numeric difference. - self.assertTrue(array_equal(c1, c2)) - - def test_gmm(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - [1, 2], - [8, 9], - [-4, -3], - [-6, -7], - ]) - clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=1) - labels = clusters.predict(data).collect() - self.assertEqual(labels[0], labels[1]) - self.assertEqual(labels[2], labels[3]) - - def test_gmm_deterministic(self): - from pyspark.mllib.clustering import GaussianMixture - x = range(0, 100, 10) - y = range(0, 100, 10) - data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) - clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEqual(round(c1, 7), round(c2, 7)) - - def test_gmm_with_initial_model(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - (-10, -5), (-9, -4), (10, 5), (9, 4) - ]) - - gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63) - gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63, initialModel=gmm1) - self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ - RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel - data = [ - LabeledPoint(0.0, [1, 0, 0]), - LabeledPoint(1.0, [0, 1, 1]), - LabeledPoint(0.0, [2, 0, 0]), - LabeledPoint(1.0, [0, 2, 1]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - temp_dir = tempfile.mkdtemp() - - lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - 
self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd, iterations=10) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - dt_model_dir = os.path.join(temp_dir, "dt") - dt_model.save(self.sc, dt_model_dir) - same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) - self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) - - rf_model = RandomForest.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, - maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - rf_model_dir = os.path.join(temp_dir, "rf") - rf_model.save(self.sc, rf_model_dir) - same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) - self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) - - gbt_model = GradientBoostedTrees.trainClassifier( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - gbt_model_dir = os.path.join(temp_dir, "gbt") - gbt_model.save(self.sc, gbt_model_dir) - same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) - self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) - - try: - rmtree(temp_dir) - except OSError: - pass - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees - data = [ - LabeledPoint(-1.0, [0, -1]), - LabeledPoint(1.0, [0, 1]), - LabeledPoint(-1.0, [0, -2]), - LabeledPoint(1.0, [0, 2]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd, iterations=10) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = 
RidgeRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - gbt_model = GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - try: - LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - except ValueError: - self.fail() - - # Verify that maxBins is being passed through - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) - with self.assertRaises(Exception) as cm: - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) - - -class StatTests(MLlibTestCase): - # SPARK-4023 - def test_col_with_different_rdds(self): - # numpy - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(1000, summary.count()) - # array - data = self.sc.parallelize([range(10)] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - # array - data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - - def test_col_norms(self): - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(10, len(summary.normL1())) - self.assertEqual(10, len(summary.normL2())) - - data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) - summary2 = Statistics.colStats(data2) - self.assertEqual(array([45.0]), summary2.normL1()) - import math - expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) - self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = 
self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(MLlibTestCase): - - """ - Test both vector operations and MLlib algorithms with SciPy sparse matrices, - if SciPy is available. 
- """ - - def test_serialize(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(lil)) - self.assertEqual(sv, _convert_to_vector(lil.tocsc())) - self.assertEqual(sv, _convert_to_vector(lil.tocoo())) - self.assertEqual(sv, _convert_to_vector(lil.tocsr())) - self.assertEqual(sv, _convert_to_vector(lil.todok())) - - def serialize(l): - return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEqual(sv, serialize(lil)) - self.assertEqual(sv, serialize(lil.tocsc())) - self.assertEqual(sv, serialize(lil.tocsr())) - self.assertEqual(sv, serialize(lil.todok())) - - def test_convert_to_vector(self): - from scipy.sparse import csc_matrix - # Create a CSC matrix with non-sorted indices - indptr = array([0, 2]) - indices = array([3, 1]) - data = array([2.0, 1.0]) - csc = csc_matrix((data, indices, indptr)) - self.assertFalse(csc.has_sorted_indices) - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(csc)) - - def test_dot(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEqual(10.0, dv.dot(lil)) - - def test_squared_distance(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 3 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEqual(15.0, dv.squared_distance(lil)) - self.assertEqual(15.0, sv.squared_distance(lil)) - - def scipy_matrix(self, size, values): - """Create a column SciPy matrix from a dictionary of values""" - from scipy.sparse import lil_matrix - lil = lil_matrix((size, 1)) - for key, value in values.items(): - lil[key, 0] = value - return lil - - def test_clustering(self): - from pyspark.mllib.clustering import KMeans - data = [ - self.scipy_matrix(3, {1: 1.0}), - self.scipy_matrix(3, {1: 1.1}), - self.scipy_matrix(3, {2: 1.0}), - self.scipy_matrix(3, {2: 1.1}) - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LogisticRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 
0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = RidgeRegressionWithSGD.train(rdd) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - -class ChiSqTestTests(MLlibTestCase): - def test_goodness_of_fit(self): - from numpy import inf - - observed = Vectors.dense([4, 6, 5]) - pearson = Statistics.chiSqTest(observed) - - # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` - self.assertEqual(pearson.statistic, 0.4) - self.assertEqual(pearson.degreesOfFreedom, 2) - self.assertAlmostEqual(pearson.pValue, 0.8187, 4) - - # Different expected and observed sum - observed1 = Vectors.dense([21, 38, 43, 80]) - expected1 = Vectors.dense([3, 5, 7, 20]) - pearson1 = Statistics.chiSqTest(observed1, expected1) - - # Results validated against the R command - # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` - self.assertAlmostEqual(pearson1.statistic, 14.1429, 4) - self.assertEqual(pearson1.degreesOfFreedom, 3) - self.assertAlmostEqual(pearson1.pValue, 0.002717, 4) - - # Vectors with different sizes - observed3 = Vectors.dense([1.0, 2.0, 3.0]) - expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0]) - self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3) - - # Negative counts in observed - neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) - - # Count = 0.0 in expected but not observed - zero_expected = Vectors.dense([1.0, 0.0, 3.0]) - pearson_inf = Statistics.chiSqTest(observed, zero_expected) - self.assertEqual(pearson_inf.statistic, inf) - self.assertEqual(pearson_inf.degreesOfFreedom, 2) - 
self.assertEqual(pearson_inf.pValue, 0.0) - - # 0.0 in expected and observed simultaneously - zero_observed = Vectors.dense([2.0, 0.0, 1.0]) - self.assertRaises( - IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) - - def test_matrix_independence(self): - data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] - chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - - # Results validated against R command - # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` - self.assertAlmostEqual(chi.statistic, 21.9958, 4) - self.assertEqual(chi.degreesOfFreedom, 6) - self.assertAlmostEqual(chi.pValue, 0.001213, 4) - - # Negative counts - neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) - - # Row sum = 0.0 - row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) - - # Column sum = 0.0 - col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) - - def test_chi_sq_pearson(self): - data = [ - LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), - LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), - LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), - LabeledPoint(1.0, Vectors.dense([3.5, 40.0])) - ] - - for numParts in [2, 4, 6, 8]: - chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts)) - feature1 = chi[0] - self.assertEqual(feature1.statistic, 0.75) - self.assertEqual(feature1.degreesOfFreedom, 2) - self.assertAlmostEqual(feature1.pValue, 0.6873, 4) - - feature2 = chi[1] - self.assertEqual(feature2.statistic, 1.5) - self.assertEqual(feature2.degreesOfFreedom, 3) - self.assertAlmostEqual(feature2.pValue, 0.6823, 4) - - def test_right_number_of_results(self): - num_cols = 1001 - sparse_data = [ - LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])), - LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)])) - ] - chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data)) - self.assertEqual(len(chi), num_cols) - self.assertIsNotNone(chi[1000]) - - -class KolmogorovSmirnovTest(MLlibTestCase): - - def test_R_implementation_equivalence(self): - data = self.sc.parallelize([ - 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, - -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, - -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, - -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, - 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 - ]) - model = Statistics.kolmogorovSmirnovTest(data, "norm") - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - -class SerDeTest(MLlibTestCase): - def test_to_java_object_rdd(self): # SPARK-6660 - data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) - self.assertEqual(_to_java_object_rdd(data).count(), 10) - - -class FeatureTest(MLlibTestCase): - def test_idf_model(self): - data = [ - Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), - Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), - Vectors.dense([1, 4, 1, 0, 0, 4, 9, 
0, 1, 2, 0]), - Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) - ] - model = IDF().fit(self.sc.parallelize(data, 2)) - idf = model.idf() - self.assertEqual(len(idf), 11) - - -class Word2VecTests(MLlibTestCase): - def test_word2vec_setters(self): - model = Word2Vec() \ - .setVectorSize(2) \ - .setLearningRate(0.01) \ - .setNumPartitions(2) \ - .setNumIterations(10) \ - .setSeed(1024) \ - .setMinCount(3) \ - .setWindowSize(6) - self.assertEqual(model.vectorSize, 2) - self.assertTrue(model.learningRate < 0.02) - self.assertEqual(model.numPartitions, 2) - self.assertEqual(model.numIterations, 10) - self.assertEqual(model.seed, 1024) - self.assertEqual(model.minCount, 3) - self.assertEqual(model.windowSize, 6) - - def test_word2vec_get_vectors(self): - data = [ - ["a", "b", "c", "d", "e", "f", "g"], - ["a", "b", "c", "d", "e", "f"], - ["a", "b", "c", "d", "e"], - ["a", "b", "c", "d"], - ["a", "b", "c"], - ["a", "b"], - ["a"] - ] - model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEqual(len(model.getVectors()), 3) - - -class StandardScalerTests(MLlibTestCase): - def test_model_setters(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertIsNotNone(model.setWithMean(True)) - self.assertIsNotNone(model.setWithStd(True)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) - - def test_model_transform(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) - - -class ElementwiseProductTests(MLlibTestCase): - def test_model_transform(self): - weight = Vectors.dense([3, 2, 1]) - - densevec = Vectors.dense([4, 5, 6]) - sparsevec = Vectors.sparse(3, [0], [1]) - eprod = ElementwiseProduct(weight) - self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) - self.assertEqual( - eprod.transform(sparsevec), SparseVector(3, [0], [3])) - - -class StreamingKMeansTest(MLLibStreamingTestCase): - def test_model_params(self): - """Test that the model params are set correctly""" - stkm = StreamingKMeans() - stkm.setK(5).setDecayFactor(0.0) - self.assertEqual(stkm._k, 5) - self.assertEqual(stkm._decayFactor, 0.0) - - # Model not set yet. 
- self.assertIsNone(stkm.latestModel()) - self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) - - stkm.setInitialCenters( - centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEqual( - stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) - - def test_accuracy_for_single_center(self): - """Test that parameters obtained are correct for a single center.""" - centers, batches = self.streamingKMeansDataGenerator( - batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) - stkm = StreamingKMeans(1) - stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) - input_stream = self.ssc.queueStream( - [self.sc.parallelize(batch, 1) for batch in batches]) - stkm.trainOn(input_stream) - - self.ssc.start() - - def condition(): - self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) - return True - self._eventually(condition, catch_assertions=True) - - realCenters = array_sum(array(centers), axis=0) - for i in range(5): - modelCenters = stkm.latestModel().centers[0][i] - self.assertAlmostEqual(centers[0][i], modelCenters, 1) - self.assertAlmostEqual(realCenters[i], modelCenters, 1) - - def streamingKMeansDataGenerator(self, batches, numPoints, - k, d, r, seed, centers=None): - rng = random.RandomState(seed) - - # Generate centers. - centers = [rng.randn(d) for i in range(k)] - - return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) - for j in range(numPoints)] - for i in range(batches)] - - def test_trainOn_model(self): - """Test the model on toy data with four clusters.""" - stkm = StreamingKMeans() - initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] - stkm.setInitialCenters( - centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) - - # Create a toy dataset by setting a tiny offset for each point. - offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] - batches = [] - for offset in offsets: - batches.append([[offset[0] + center[0], offset[1] + center[1]] - for center in initCenters]) - - batches = [self.sc.parallelize(batch, 1) for batch in batches] - input_stream = self.ssc.queueStream(batches) - stkm.trainOn(input_stream) - self.ssc.start() - - # Give enough time to train the model. 
- def condition(): - finalModel = stkm.latestModel() - self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) - return True - self._eventually(condition, catch_assertions=True) - - def test_predictOn_model(self): - """Test that the model predicts correctly on toy data.""" - stkm = StreamingKMeans() - stkm._model = StreamingKMeansModel( - clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], - clusterWeights=[1.0, 1.0, 1.0, 1.0]) - - predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] - predict_stream = self.ssc.queueStream(predict_data) - predict_val = stkm.predictOn(predict_stream) - - result = [] - - def update(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - result.append(rdd_collect) - - predict_val.foreachRDD(update) - self.ssc.start() - - def condition(): - self.assertEqual(result, [[0], [1], [2], [3]]) - return True - - self._eventually(condition, catch_assertions=True) - - @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") - def test_trainOn_predictOn(self): - """Test that prediction happens on the updated model.""" - stkm = StreamingKMeans(decayFactor=0.0, k=2) - stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) - - # Since decay factor is set to zero, once the first batch - # is passed the clusterCenters are updated to [-0.5, 0.7] - # which causes 0.2 & 0.3 to be classified as 1, even though the - # classification based in the initial model would have been 0 - # proving that the model is updated. - batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [self.sc.parallelize(batch) for batch in batches] - input_stream = self.ssc.queueStream(batches) - predict_results = [] - - def collect(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - predict_results.append(rdd_collect) - - stkm.trainOn(input_stream) - predict_stream = stkm.predictOn(input_stream) - predict_stream.foreachRDD(collect) - - self.ssc.start() - - def condition(): - self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) - return True - - self._eventually(condition, catch_assertions=True) - - -class LinearDataGeneratorTests(MLlibTestCase): - def test_dim(self): - linear_data = LinearDataGenerator.generateLinearInput( - intercept=0.0, weights=[0.0, 0.0, 0.0], - xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], - nPoints=4, seed=0, eps=0.1) - self.assertEqual(len(linear_data), 4) - for point in linear_data: - self.assertEqual(len(point.features), 3) - - linear_data = LinearDataGenerator.generateLinearRDD( - sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, - nParts=2, intercept=0.0).collect() - self.assertEqual(len(linear_data), 6) - for point in linear_data: - self.assertEqual(len(point.features), 2) - - -class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): - - @staticmethod - def generateLogisticInput(offset, scale, nPoints, seed): - """ - Generate 1 / (1 + exp(-x * scale + offset)) - - where, - x is randomnly distributed and the threshold - and labels for each sample in x is obtained from a random uniform - distribution. - """ - rng = random.RandomState(seed) - x = rng.randn(nPoints) - sigmoid = 1. 
/ (1 + exp(-(dot(x, scale) + offset))) - y_p = rng.rand(nPoints) - cut_off = y_p <= sigmoid - y_p[cut_off] = 1.0 - y_p[~cut_off] = 0.0 - return [ - LabeledPoint(y_p[i], Vectors.dense([x[i]])) - for i in range(nPoints)] - - def test_parameter_accuracy(self): - """ - Test that the final value of weights is close to the desired value. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - - self.ssc.start() - - def condition(): - rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 - self.assertAlmostEqual(rel, 0.1, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_convergence(self): - """ - Test that weights converge to the required value on toy data. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - models = [] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - input_stream.foreachRDD( - lambda x: models.append(slr.latestModel().weights[0])) - - self.ssc.start() - - def condition(): - self.assertEqual(len(models), len(input_batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, 60.0, catch_assertions=True) - - t_models = array(models) - diff = t_models[1:] - t_models[:-1] - # Test that weights improve with a small tolerance - self.assertTrue(all(diff >= -0.1)) - self.assertTrue(array_sum(diff > 0) > 1) - - @staticmethod - def calculate_accuracy_error(true, predicted): - return sum(abs(array(true) - array(predicted))) / len(true) - - def test_predictions(self): - """Test predicted values on a toy model.""" - input_batches = [] - for i in range(20): - batch = self.sc.parallelize( - self.generateLogisticInput(0, 1.5, 100, 42 + i)) - input_batches.append(batch.map(lambda x: (x.label, x.features))) - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([1.5]) - predict_stream = slr.predictOnValues(input_stream) - true_predicted = [] - predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) - self.ssc.start() - - def condition(): - self.assertEqual(len(true_predicted), len(input_batches)) - return True - - self._eventually(condition, catch_assertions=True) - - # Test that the accuracy error is no more than 0.4 on each batch. - for batch in true_predicted: - true, predicted = zip(*batch) - self.assertTrue( - self.calculate_accuracy_error(true, predicted) < 0.4) - - def test_training_and_prediction(self): - """Test that the model improves on toy data with no. 
of batches""" - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.01, numIterations=25) - slr.setInitialWeights([-0.1]) - errors = [] - - def collect_errors(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(self.calculate_accuracy_error(true, predicted)) - - true_predicted = [] - input_stream = self.ssc.queueStream(input_batches) - predict_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - ps = slr.predictOnValues(predict_stream) - ps.foreachRDD(lambda x: collect_errors(x)) - - self.ssc.start() - - def condition(): - # Test that the improvement in error is > 0.3 - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 0.3) - if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): - - def assertArrayAlmostEqual(self, array1, array2, dec): - for i, j in array1, array2: - self.assertAlmostEqual(i, j, dec) - - def test_parameter_accuracy(self): - """Test that coefs are predicted accurately by fitting on toy data.""" - - # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients - # (10, 10) - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0, 0.0]) - xMean = [0.0, 0.0] - xVariance = [1.0 / 3.0, 1.0 / 3.0] - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - input_stream = self.ssc.queueStream(batches) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertArrayAlmostEqual( - slr.latestModel().weights.array, [10., 10.], 1) - self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_parameter_convergence(self): - """Test that the model parameters improve with streaming data.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - model_weights = [] - input_stream = self.ssc.queueStream(batches) - input_stream.foreachRDD( - lambda x: model_weights.append(slr.latestModel().weights[0])) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertEqual(len(model_weights), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - w = array(model_weights) - diff = w[1:] - w[:-1] - self.assertTrue(all(diff >= -0.1)) - - def test_prediction(self): - """Test prediction on a model with weights already set.""" - # Create a model with initial Weights equal to coefs - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([10.0, 10.0]) - - # Create ten batches with 100 sample points in each. 
- batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], - 100, 42 + i, 0.1) - batches.append( - self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) - - input_stream = self.ssc.queueStream(batches) - output_stream = slr.predictOnValues(input_stream) - samples = [] - output_stream.foreachRDD(lambda x: samples.append(x.collect())) - - self.ssc.start() - - def condition(): - self.assertEqual(len(samples), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - # Test that mean absolute error on each batch is less than 0.1 - for batch in samples: - true, predicted = zip(*batch) - self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) - - def test_train_prediction(self): - """Test that error on test data improves as model is trained.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in batches] - errors = [] - - def func(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(mean(abs(true) - abs(predicted))) - - input_stream = self.ssc.queueStream(batches) - output_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - output_stream = slr.predictOnValues(output_stream) - output_stream.foreachRDD(func) - self.ssc.start() - - def condition(): - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 2) - if len(errors) >= 3 and errors[1] - errors[-1] > 2: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class MLUtilsTests(MLlibTestCase): - def test_append_bias(self): - data = [2.0, 2.0, 2.0] - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_vector(self): - data = Vectors.dense([2.0, 2.0, 2.0]) - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_sp_vector(self): - data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) - expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) - # Returned value must be SparseVector - ret = MLUtils.appendBias(data) - self.assertEqual(ret, expected) - self.assertEqual(type(ret), SparseVector) - - def test_load_vectors(self): - import shutil - data = [ - [1.0, 2.0, 3.0], - [1.0, 2.0, 3.0] - ] - temp_dir = tempfile.mkdtemp() - load_vectors_path = os.path.join(temp_dir, "test_load_vectors") - try: - self.sc.parallelize(data).saveAsTextFile(load_vectors_path) - ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) - ret = ret_rdd.collect() - self.assertEqual(len(ret), 2) - self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) - self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) - except: - self.fail() - finally: - shutil.rmtree(load_vectors_path) - - -class ALSTests(MLlibTestCase): - - def test_als_ratings_serialize(self): - r = Rating(7, 1123, 3.14) - jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) - nr = 
ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) - self.assertEqual(r.user, nr.user) - self.assertEqual(r.product, nr.product) - self.assertAlmostEqual(r.rating, nr.rating, 2) - - def test_als_ratings_id_long_error(self): - r = Rating(1205640308657491975, 50233468418, 1.0) - # rating user id exceeds max int value, should fail when pickled - self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, - bytearray(ser.dumps(r))) - - -class HashingTFTest(MLlibTestCase): - - def test_binary_term_freqs(self): - hashingTF = HashingTF(100).setBinary(True) - doc = "a a b c c c".split(" ") - n = hashingTF.numFeatures - output = hashingTF.transform(doc).toArray() - expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, - hashingTF.indexOf("b"): 1.0, - hashingTF.indexOf("c"): 1.0}).toArray() - for i in range(0, n): - self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(output[i])) - - -class DimensionalityReductionTests(MLlibTestCase): - - denseData = [ - Vectors.dense([0.0, 1.0, 2.0]), - Vectors.dense([3.0, 4.0, 5.0]), - Vectors.dense([6.0, 7.0, 8.0]), - Vectors.dense([9.0, 0.0, 1.0]) - ] - sparseData = [ - Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), - Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), - Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), - Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) - ] - - def assertEqualUpToSign(self, vecA, vecB): - eq1 = vecA - vecB - eq2 = vecA + vecB - self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) - - def test_svd(self): - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - m = 4 - n = 3 - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - rm = mat.computeSVD(k, computeU=True) - self.assertEqual(rm.s.size, k) - self.assertEqual(rm.U.numRows(), m) - self.assertEqual(rm.U.numCols(), k) - self.assertEqual(rm.V.numRows, n) - self.assertEqual(rm.V.numCols, k) - - # Test that U returned is None if computeU is set to False. - self.assertEqual(mat.computeSVD(1).U, None) - - # Test that low rank matrices cannot have number of singular values - # greater than a limit. - rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) - self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) - - def test_pca(self): - expected_pcs = array([ - [0.0, 1.0, 0.0], - [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], - [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] - ]) - n = 3 - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - pcs = mat.computePrincipalComponents(k) - self.assertEqual(pcs.numRows, n) - self.assertEqual(pcs.numCols, k) - - # We can just test the updated principal component for equality. 
- self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) - - -class FPGrowthTest(MLlibTestCase): - - def test_fpgrowth(self): - data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] - rdd = self.sc.parallelize(data, 2) - model1 = FPGrowth.train(rdd, 0.6, 2) - # use default data partition number when numPartitions is not specified - model2 = FPGrowth.train(rdd, 0.6) - self.assertEqual(sorted(model1.freqItemsets().collect()), - sorted(model2.freqItemsets().collect())) - -if __name__ == "__main__": - from pyspark.mllib.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - sc.stop() diff --git a/python/pyspark/mllib/tests/__init__.py b/python/pyspark/mllib/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/mllib/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py new file mode 100644 index 0000000000000..8a3454144a115 --- /dev/null +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -0,0 +1,313 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import sys +import tempfile +from shutil import rmtree + +from numpy import array, array_equal + +from py4j.protocol import Py4JJavaError + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.fpm import FPGrowth +from pyspark.mllib.recommendation import Rating +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql.utils import IllegalArgumentException +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase + + +ser = make_serializer() + + +class ListTests(MLlibTestCase): + + """ + Test MLlib algorithms on plain lists, to make sure they're passed through + as NumPy arrays. + """ + + def test_bisecting_kmeans(self): + from pyspark.mllib.clustering import BisectingKMeans + data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) + bskm = BisectingKMeans() + model = bskm.train(self.sc.parallelize(data, 2), k=4) + p = array([0.0, 0.0]) + rdd_p = self.sc.parallelize([p]) + self.assertEqual(model.predict(p), model.predict(rdd_p).first()) + self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) + self.assertEqual(model.k, len(model.clusterCenters)) + + def test_kmeans(self): + from pyspark.mllib.clustering import KMeans + data = [ + [0, 1.1], + [0, 1.2], + [1.1, 0], + [1.2, 0], + ] + clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", + initializationSteps=7, epsilon=1e-4) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) + + def test_kmeans_deterministic(self): + from pyspark.mllib.clustering import KMeans + X = range(0, 100, 10) + Y = range(0, 100, 10) + data = [[x, y] for x, y in zip(X, Y)] + clusters1 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", + seed=42, initializationSteps=7, epsilon=1e-4) + clusters2 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", + seed=42, initializationSteps=7, epsilon=1e-4) + centers1 = clusters1.centers + centers2 = clusters2.centers + for c1, c2 in zip(centers1, centers2): + # TODO: Allow small numeric difference. 
+ self.assertTrue(array_equal(c1, c2)) + + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=1) + labels = clusters.predict(data).collect() + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEqual(round(c1, 7), round(c2, 7)) + + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + + def test_classification(self): + from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest, \ + RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel + data = [ + LabeledPoint(0.0, [1, 0, 0]), + LabeledPoint(1.0, [0, 1, 1]), + LabeledPoint(0.0, [2, 0, 0]), + LabeledPoint(1.0, [0, 2, 1]) + ] + rdd = self.sc.parallelize(data) + features = [p.features.tolist() for p in data] + + temp_dir = tempfile.mkdtemp() + + lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + svm_model = SVMWithSGD.train(rdd, iterations=10) + self.assertTrue(svm_model.predict(features[0]) <= 0) + self.assertTrue(svm_model.predict(features[1]) > 0) + self.assertTrue(svm_model.predict(features[2]) <= 0) + self.assertTrue(svm_model.predict(features[3]) > 0) + + nb_model = NaiveBayes.train(rdd) + self.assertTrue(nb_model.predict(features[0]) <= 0) + self.assertTrue(nb_model.predict(features[1]) > 0) + self.assertTrue(nb_model.predict(features[2]) <= 0) + self.assertTrue(nb_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + dt_model_dir = os.path.join(temp_dir, "dt") + dt_model.save(self.sc, dt_model_dir) + same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) + self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) + + rf_model = RandomForest.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, + maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + 
self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + rf_model_dir = os.path.join(temp_dir, "rf") + rf_model.save(self.sc, rf_model_dir) + same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) + self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) + + gbt_model = GradientBoostedTrees.trainClassifier( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + gbt_model_dir = os.path.join(temp_dir, "gbt") + gbt_model.save(self.sc, gbt_model_dir) + same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) + self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) + + try: + rmtree(temp_dir) + except OSError: + pass + + def test_regression(self): + from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ + RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees + data = [ + LabeledPoint(-1.0, [0, -1]), + LabeledPoint(1.0, [0, 1]), + LabeledPoint(-1.0, [0, -2]), + LabeledPoint(1.0, [0, 2]) + ] + rdd = self.sc.parallelize(data) + features = [p.features.tolist() for p in data] + + lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + lasso_model = LassoWithSGD.train(rdd, iterations=10) + self.assertTrue(lasso_model.predict(features[0]) <= 0) + self.assertTrue(lasso_model.predict(features[1]) > 0) + self.assertTrue(lasso_model.predict(features[2]) <= 0) + self.assertTrue(lasso_model.predict(features[3]) > 0) + + rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(rr_model.predict(features[0]) <= 0) + self.assertTrue(rr_model.predict(features[1]) > 0) + self.assertTrue(rr_model.predict(features[2]) <= 0) + self.assertTrue(rr_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + rf_model = RandomForest.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + try: + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + 
RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + except ValueError: + self.fail() + + # Verify that maxBins is being passed through + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) + with self.assertRaises(Exception) as cm: + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) + + +class ALSTests(MLlibTestCase): + + def test_als_ratings_serialize(self): + r = Rating(7, 1123, 3.14) + jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) + self.assertEqual(r.user, nr.user) + self.assertEqual(r.product, nr.product) + self.assertAlmostEqual(r.rating, nr.rating, 2) + + def test_als_ratings_id_long_error(self): + r = Rating(1205640308657491975, 50233468418, 1.0) + # rating user id exceeds max int value, should fail when pickled + self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, + bytearray(ser.dumps(r))) + + +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py new file mode 100644 index 0000000000000..48ed810fa6fcb --- /dev/null +++ b/python/pyspark/mllib/tests/test_feature.py @@ -0,0 +1,201 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +from math import sqrt + +from numpy import array, random, exp, abs, tile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +from pyspark.mllib.feature import HashingTF, IDF, StandardScaler, ElementwiseProduct, Word2Vec +from pyspark.testing.mllibutils import MLlibTestCase + + +class FeatureTest(MLlibTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + + +class Word2VecTests(MLlibTestCase): + def test_word2vec_setters(self): + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) \ + .setWindowSize(6) + self.assertEqual(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) + self.assertEqual(model.windowSize, 6) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEqual(len(model.getVectors()), 3) + + +class StandardScalerTests(MLlibTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + +class HashingTFTest(MLlibTestCase): + + def test_binary_term_freqs(self): + hashingTF = HashingTF(100).setBinary(True) + doc = "a a b c c c".split(" ") + n = hashingTF.numFeatures + output = hashingTF.transform(doc).toArray() + expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, + hashingTF.indexOf("b"): 1.0, + hashingTF.indexOf("c"): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(output[i])) + + +class DimensionalityReductionTests(MLlibTestCase): + + denseData = [ + Vectors.dense([0.0, 1.0, 2.0]), + Vectors.dense([3.0, 4.0, 
5.0]), + Vectors.dense([6.0, 7.0, 8.0]), + Vectors.dense([9.0, 0.0, 1.0]) + ] + sparseData = [ + Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), + Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), + Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), + Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) + ] + + def assertEqualUpToSign(self, vecA, vecB): + eq1 = vecA - vecB + eq2 = vecA + vecB + self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) + + def test_svd(self): + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + m = 4 + n = 3 + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + rm = mat.computeSVD(k, computeU=True) + self.assertEqual(rm.s.size, k) + self.assertEqual(rm.U.numRows(), m) + self.assertEqual(rm.U.numCols(), k) + self.assertEqual(rm.V.numRows, n) + self.assertEqual(rm.V.numCols, k) + + # Test that U returned is None if computeU is set to False. + self.assertEqual(mat.computeSVD(1).U, None) + + # Test that low rank matrices cannot have number of singular values + # greater than a limit. + rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) + self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) + + def test_pca(self): + expected_pcs = array([ + [0.0, 1.0, 0.0], + [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], + [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] + ]) + n = 3 + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + pcs = mat.computePrincipalComponents(k) + self.assertEqual(pcs.numRows, n) + self.assertEqual(pcs.numCols, k) + + # We can just test the updated principal component for equality. + self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_feature import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py new file mode 100644 index 0000000000000..550e32a9af024 --- /dev/null +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +import array as pyarray + +from numpy import array, array_equal, zeros, arange, tile, ones, inf +from numpy import sum as array_sum + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import pyspark.ml.linalg as newlinalg +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.regression import LabeledPoint +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase + +_have_scipy = False +try: + import scipy.sparse + _have_scipy = True +except: + # No SciPy, but that's okay, we'll skip those tests + pass + + +ser = make_serializer() + + +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + + +class VectorTests(MLlibTestCase): + + def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) + jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) + + def test_serialize(self): + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) + + def test_dot(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) + mat = array([[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) + self.assertEqual(10.0, sv.dot(dv)) + self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEqual(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEqual(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEqual(7.0, sv.dot(arr)) + + def test_squared_distance(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) + + def 
test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + + def test_sparse_vector_indexing(self): + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) 
+ for ind in [5, -6]: + self.assertRaises(IndexError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(IndexError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEqual(mat[i, j], expected[i][j]) + + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + + # Test conversion to dense and sparse. 
+ smnew = sm1.toDense().toSparse() + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEqual(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEqual(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_parse_vector(self): + a = DenseVector([]) + self.assertEqual(str(a), '[]') + self.assertEqual(Vectors.parse(str(a)), a) + a = DenseVector([3, 4, 6, 7]) + self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') + self.assertEqual(Vectors.parse(str(a)), a) + a = SparseVector(4, [], []) + self.assertEqual(str(a), '(4,[],[])') + self.assertEqual(SparseVector.parse(str(a)), a) + a = SparseVector(4, [0, 2], [3, 4]) + self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') + self.assertEqual(Vectors.parse(str(a)), a) + a = SparseVector(10, [0, 1], [4, 5]) + self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + def test_ml_mllib_vector_conversion(self): + # to ml + # dense + mllibDV = Vectors.dense([1, 2, 3]) + mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) + mlDV2 = mllibDV.asML() + self.assertEqual(mlDV2, mlDV1) + # sparse + mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV2 = mllibSV.asML() + self.assertEqual(mlSV2, mlSV1) + # from ml + # dense + mllibDV1 = Vectors.dense([1, 2, 3]) + mlDV = newlinalg.Vectors.dense([1, 2, 3]) + mllibDV2 = Vectors.fromML(mlDV) + self.assertEqual(mllibDV1, mllibDV2) + # sparse + mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mllibSV2 = Vectors.fromML(mlSV) + self.assertEqual(mllibSV1, mllibSV2) + + def test_ml_mllib_matrix_conversion(self): + # to ml + # dense + mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM2 = mllibDM.asML() + self.assertEqual(mlDM2, mlDM1) + # transposed + mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt1 = 
newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt2 = mllibDMt.asML() + self.assertEqual(mlDMt2, mlDMt1) + # sparse + mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM2 = mllibSM.asML() + self.assertEqual(mlSM2, mlSM1) + # transposed + mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt2 = mllibSMt.asML() + self.assertEqual(mlSMt2, mlSMt1) + # from ml + # dense + mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) + mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) + mllibDM2 = Matrices.fromML(mlDM) + self.assertEqual(mllibDM1, mllibDM2) + # transposed + mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True) + mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) + mllibDMt2 = Matrices.fromML(mlDMt) + self.assertEqual(mllibDMt1, mllibDMt2) + # sparse + mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mllibSM2 = Matrices.fromML(mlSM) + self.assertEqual(mllibSM1, mllibSM2) + # transposed + mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mllibSMt2 = Matrices.fromML(mlSMt) + self.assertEqual(mllibSMt1, mllibSMt2) + + +class VectorUDTTests(MLlibTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + df = rdd.toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = df.rdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.rdd.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + +@unittest.skipIf(not 
_have_scipy, "SciPy not installed") +class SciPyTests(MLlibTestCase): + + """ + Test both vector operations and MLlib algorithms with SciPy sparse matrices, + if SciPy is available. + """ + + def test_serialize(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 1 + lil[3, 0] = 2 + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) + + def serialize(l): + return ser.loads(ser.dumps(_convert_to_vector(l))) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) + + def test_convert_to_vector(self): + from scipy.sparse import csc_matrix + # Create a CSC matrix with non-sorted indices + indptr = array([0, 2]) + indices = array([3, 1]) + data = array([2.0, 1.0]) + csc = csc_matrix((data, indices, indptr)) + self.assertFalse(csc.has_sorted_indices) + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(csc)) + + def test_dot(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 1 + lil[3, 0] = 2 + dv = DenseVector(array([1., 2., 3., 4.])) + self.assertEqual(10.0, dv.dot(lil)) + + def test_squared_distance(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 3 + lil[3, 0] = 2 + dv = DenseVector(array([1., 2., 3., 4.])) + sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) + + def scipy_matrix(self, size, values): + """Create a column SciPy matrix from a dictionary of values""" + from scipy.sparse import lil_matrix + lil = lil_matrix((size, 1)) + for key, value in values.items(): + lil[key, 0] = value + return lil + + def test_clustering(self): + from pyspark.mllib.clustering import KMeans + data = [ + self.scipy_matrix(3, {1: 1.0}), + self.scipy_matrix(3, {1: 1.1}), + self.scipy_matrix(3, {2: 1.0}), + self.scipy_matrix(3, {2: 1.1}) + ] + clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) + + def test_classification(self): + from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree + data = [ + LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), + LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) + ] + rdd = self.sc.parallelize(data) + features = [p.features for p in data] + + lr_model = LogisticRegressionWithSGD.train(rdd) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + svm_model = SVMWithSGD.train(rdd) + self.assertTrue(svm_model.predict(features[0]) <= 0) + self.assertTrue(svm_model.predict(features[1]) > 0) + self.assertTrue(svm_model.predict(features[2]) <= 0) + self.assertTrue(svm_model.predict(features[3]) > 0) + + nb_model = NaiveBayes.train(rdd) + 
self.assertTrue(nb_model.predict(features[0]) <= 0) + self.assertTrue(nb_model.predict(features[1]) > 0) + self.assertTrue(nb_model.predict(features[2]) <= 0) + self.assertTrue(nb_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + def test_regression(self): + from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ + RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree + data = [ + LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), + LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) + ] + rdd = self.sc.parallelize(data) + features = [p.features for p in data] + + lr_model = LinearRegressionWithSGD.train(rdd) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + lasso_model = LassoWithSGD.train(rdd) + self.assertTrue(lasso_model.predict(features[0]) <= 0) + self.assertTrue(lasso_model.predict(features[1]) > 0) + self.assertTrue(lasso_model.predict(features[2]) <= 0) + self.assertTrue(lasso_model.predict(features[3]) > 0) + + rr_model = RidgeRegressionWithSGD.train(rdd) + self.assertTrue(rr_model.predict(features[0]) <= 0) + self.assertTrue(rr_model.predict(features[1]) > 0) + self.assertTrue(rr_model.predict(features[2]) <= 0) + self.assertTrue(rr_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_linalg import * + if not _have_scipy: + print("NOTE: Skipping SciPy tests as it does not seem to be installed") + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) + if not _have_scipy: + print("NOTE: SciPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py new file mode 100644 index 0000000000000..5e74087d8fa7b --- /dev/null +++ b/python/pyspark/mllib/tests/test_stat.py @@ -0,0 +1,197 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import array as pyarray + +from numpy import array + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.sql.utils import IllegalArgumentException +from pyspark.testing.mllibutils import MLlibTestCase + + +class StatTests(MLlibTestCase): + # SPARK-4023 + def test_col_with_different_rdds(self): + # numpy + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(1000, summary.count()) + # array + data = self.sc.parallelize([range(10)] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + # array + data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + + +class ChiSqTestTests(MLlibTestCase): + def test_goodness_of_fit(self): + from numpy import inf + + observed = Vectors.dense([4, 6, 5]) + pearson = Statistics.chiSqTest(observed) + + # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` + self.assertEqual(pearson.statistic, 0.4) + self.assertEqual(pearson.degreesOfFreedom, 2) + self.assertAlmostEqual(pearson.pValue, 0.8187, 4) + + # Different expected and observed sum + observed1 = Vectors.dense([21, 38, 43, 80]) + expected1 = Vectors.dense([3, 5, 7, 20]) + pearson1 = Statistics.chiSqTest(observed1, expected1) + + # Results validated against the R command + # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` + self.assertAlmostEqual(pearson1.statistic, 14.1429, 4) + self.assertEqual(pearson1.degreesOfFreedom, 3) + self.assertAlmostEqual(pearson1.pValue, 0.002717, 4) + + # Vectors with different sizes + observed3 = Vectors.dense([1.0, 2.0, 3.0]) + expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0]) + self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3) + + # Negative counts in observed + neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) + + # Count = 0.0 in expected but not observed + zero_expected = Vectors.dense([1.0, 0.0, 3.0]) + pearson_inf = 
Statistics.chiSqTest(observed, zero_expected) + self.assertEqual(pearson_inf.statistic, inf) + self.assertEqual(pearson_inf.degreesOfFreedom, 2) + self.assertEqual(pearson_inf.pValue, 0.0) + + # 0.0 in expected and observed simultaneously + zero_observed = Vectors.dense([2.0, 0.0, 1.0]) + self.assertRaises( + IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) + + def test_matrix_independence(self): + data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + + # Results validated against R command + # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` + self.assertAlmostEqual(chi.statistic, 21.9958, 4) + self.assertEqual(chi.degreesOfFreedom, 6) + self.assertAlmostEqual(chi.pValue, 0.001213, 4) + + # Negative counts + neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) + + # Row sum = 0.0 + row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) + + # Column sum = 0.0 + col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) + + def test_chi_sq_pearson(self): + data = [ + LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + LabeledPoint(1.0, Vectors.dense([3.5, 40.0])) + ] + + for numParts in [2, 4, 6, 8]: + chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts)) + feature1 = chi[0] + self.assertEqual(feature1.statistic, 0.75) + self.assertEqual(feature1.degreesOfFreedom, 2) + self.assertAlmostEqual(feature1.pValue, 0.6873, 4) + + feature2 = chi[1] + self.assertEqual(feature2.statistic, 1.5) + self.assertEqual(feature2.degreesOfFreedom, 3) + self.assertAlmostEqual(feature2.pValue, 0.6823, 4) + + def test_right_number_of_results(self): + num_cols = 1001 + sparse_data = [ + LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])), + LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)])) + ] + chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data)) + self.assertEqual(len(chi), num_cols) + self.assertIsNotNone(chi[1000]) + + +class KolmogorovSmirnovTest(MLlibTestCase): + + def test_R_implementation_equivalence(self): + data = self.sc.parallelize([ + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ]) + model = Statistics.kolmogorovSmirnovTest(data, "norm") + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_stat import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff 
--git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py new file mode 100644 index 0000000000000..ba95855fd4f00 --- /dev/null +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -0,0 +1,523 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +from time import time, sleep + +from numpy import array, random, exp, dot, all, mean, abs +from numpy import sum as array_sum + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark import SparkContext +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD +from pyspark.mllib.util import LinearDataGenerator +from pyspark.streaming import StreamingContext + + +class MLLibStreamingTestCase(unittest.TestCase): + def setUp(self): + self.sc = SparkContext('local[4]', "MLlib tests") + self.ssc = StreamingContext(self.sc, 1.0) + + def tearDown(self): + self.ssc.stop(False) + self.sc.stop() + + @staticmethod + def _eventually(condition, timeout=30.0, catch_assertions=False): + """ + Wait a given amount of time for a condition to pass, else fail with an error. + This is a helper utility for streaming ML tests. + :param condition: Function that checks for termination conditions. + condition() can return: + - True: Conditions met. Return without error. + - other value: Conditions not met yet. Continue. Upon timeout, + include last such value in error message. + Note that this method may be called at any time during + streaming execution (e.g., even before any results + have been created). + :param timeout: Number of seconds to wait. Default 30 seconds. + :param catch_assertions: If False (default), do not catch AssertionErrors. + If True, catch AssertionErrors; continue, but save + error to throw upon timeout. 
+ """ + start_time = time() + lastValue = None + while time() - start_time < timeout: + if catch_assertions: + try: + lastValue = condition() + except AssertionError as e: + lastValue = e + else: + lastValue = condition() + if lastValue is True: + return + sleep(0.01) + if isinstance(lastValue, AssertionError): + raise lastValue + else: + raise AssertionError( + "Test failed due to timeout after %g sec, with last condition returning: %s" + % (timeout, lastValue)) + + +class StreamingKMeansTest(MLLibStreamingTestCase): + def test_model_params(self): + """Test that the model params are set correctly""" + stkm = StreamingKMeans() + stkm.setK(5).setDecayFactor(0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) + + # Model not set yet. + self.assertIsNone(stkm.latestModel()) + self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) + + stkm.setInitialCenters( + centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) + self.assertEqual( + stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) + + def test_accuracy_for_single_center(self): + """Test that parameters obtained are correct for a single center.""" + centers, batches = self.streamingKMeansDataGenerator( + batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) + stkm = StreamingKMeans(1) + stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) + input_stream = self.ssc.queueStream( + [self.sc.parallelize(batch, 1) for batch in batches]) + stkm.trainOn(input_stream) + + self.ssc.start() + + def condition(): + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) + return True + self._eventually(condition, catch_assertions=True) + + realCenters = array_sum(array(centers), axis=0) + for i in range(5): + modelCenters = stkm.latestModel().centers[0][i] + self.assertAlmostEqual(centers[0][i], modelCenters, 1) + self.assertAlmostEqual(realCenters[i], modelCenters, 1) + + def streamingKMeansDataGenerator(self, batches, numPoints, + k, d, r, seed, centers=None): + rng = random.RandomState(seed) + + # Generate centers. + centers = [rng.randn(d) for i in range(k)] + + return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) + for j in range(numPoints)] + for i in range(batches)] + + def test_trainOn_model(self): + """Test the model on toy data with four clusters.""" + stkm = StreamingKMeans() + initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] + stkm.setInitialCenters( + centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) + + # Create a toy dataset by setting a tiny offset for each point. + offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] + batches = [] + for offset in offsets: + batches.append([[offset[0] + center[0], offset[1] + center[1]] + for center in initCenters]) + + batches = [self.sc.parallelize(batch, 1) for batch in batches] + input_stream = self.ssc.queueStream(batches) + stkm.trainOn(input_stream) + self.ssc.start() + + # Give enough time to train the model. 
+ def condition(): + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + return True + self._eventually(condition, catch_assertions=True) + + def test_predictOn_model(self): + """Test that the model predicts correctly on toy data.""" + stkm = StreamingKMeans() + stkm._model = StreamingKMeansModel( + clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], + clusterWeights=[1.0, 1.0, 1.0, 1.0]) + + predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] + predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] + predict_stream = self.ssc.queueStream(predict_data) + predict_val = stkm.predictOn(predict_stream) + + result = [] + + def update(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + result.append(rdd_collect) + + predict_val.foreachRDD(update) + self.ssc.start() + + def condition(): + self.assertEqual(result, [[0], [1], [2], [3]]) + return True + + self._eventually(condition, catch_assertions=True) + + @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") + def test_trainOn_predictOn(self): + """Test that prediction happens on the updated model.""" + stkm = StreamingKMeans(decayFactor=0.0, k=2) + stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) + + # Since decay factor is set to zero, once the first batch + # is passed the clusterCenters are updated to [-0.5, 0.7] + # which causes 0.2 & 0.3 to be classified as 1, even though the + # classification based in the initial model would have been 0 + # proving that the model is updated. + batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] + batches = [self.sc.parallelize(batch) for batch in batches] + input_stream = self.ssc.queueStream(batches) + predict_results = [] + + def collect(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + predict_results.append(rdd_collect) + + stkm.trainOn(input_stream) + predict_stream = stkm.predictOn(input_stream) + predict_stream.foreachRDD(collect) + + self.ssc.start() + + def condition(): + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + return True + + self._eventually(condition, catch_assertions=True) + + +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. 
+ """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + self.ssc.start() + + def condition(): + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + self.ssc.start() + + def condition(): + self.assertEqual(len(models), len(input_batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, 60.0, catch_assertions=True) + + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + # Test that weights improve with a small tolerance + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + self.ssc.start() + + def condition(): + self.assertEqual(len(true_predicted), len(input_batches)) + return True + + self._eventually(condition, catch_assertions=True) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. 
of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + self.ssc.start() + + def condition(): + # Test that the improvement in error is > 0.3 + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 0.3) + if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): + + def assertArrayAlmostEqual(self, array1, array2, dec): + for i, j in array1, array2: + self.assertAlmostEqual(i, j, dec) + + def test_parameter_accuracy(self): + """Test that coefs are predicted accurately by fitting on toy data.""" + + # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients + # (10, 10) + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0, 0.0]) + xMean = [0.0, 0.0] + xVariance = [1.0 / 3.0, 1.0 / 3.0] + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + input_stream = self.ssc.queueStream(batches) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_parameter_convergence(self): + """Test that the model parameters improve with streaming data.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + input_stream.foreachRDD( + lambda x: model_weights.append(slr.latestModel().weights[0])) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertEqual(len(model_weights), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + w = array(model_weights) + diff = w[1:] - w[:-1] + self.assertTrue(all(diff >= -0.1)) + + def test_prediction(self): + """Test prediction on a model with weights already set.""" + # Create a model with initial Weights equal to coefs + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([10.0, 10.0]) + + # Create ten batches with 100 sample points in each. 
+ batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], + 100, 42 + i, 0.1) + batches.append( + self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + + input_stream = self.ssc.queueStream(batches) + output_stream = slr.predictOnValues(input_stream) + samples = [] + output_stream.foreachRDD(lambda x: samples.append(x.collect())) + + self.ssc.start() + + def condition(): + self.assertEqual(len(samples), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + # Test that mean absolute error on each batch is less than 0.1 + for batch in samples: + true, predicted = zip(*batch) + self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) + + def test_train_prediction(self): + """Test that error on test data improves as model is trained.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in batches] + errors = [] + + def func(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(mean(abs(true) - abs(predicted))) + + input_stream = self.ssc.queueStream(batches) + output_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + output_stream = slr.predictOnValues(output_stream) + output_stream.foreachRDD(func) + self.ssc.start() + + def condition(): + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 2) + if len(errors) >= 3 and errors[1] - errors[-1] > 2: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_streaming_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py new file mode 100644 index 0000000000000..c924eba80484c --- /dev/null +++ b/python/pyspark/mllib/tests/test_util.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import sys +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils +from pyspark.mllib.linalg import SparseVector, DenseVector, SparseMatrix, Vectors +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.testing.mllibutils import MLlibTestCase + + +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + +class SerDeTest(MLlibTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_util import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/testing/mllibutils.py b/python/pyspark/testing/mllibutils.py new file mode 100644 index 0000000000000..9248182658f84 --- /dev/null +++ b/python/pyspark/testing/mllibutils.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark import SparkContext +from pyspark.serializers import PickleSerializer +from pyspark.sql import SparkSession + + +def make_serializer(): + return PickleSerializer() + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = SparkContext('local[4]', "MLlib tests") + self.spark = SparkSession(self.sc) + + def tearDown(self): + self.spark.stop() From 99cbc51b3250c07a3e8cc95c9b74e9d1725bac77 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 16 Nov 2018 09:51:41 -0800 Subject: [PATCH 060/145] [SPARK-26069][TESTS] Fix flaky test: RpcIntegrationSuite.sendRpcWithStreamFailures ## What changes were proposed in this pull request? The test failure is because `assertErrorAndClosed` misses one possible error message: `java.nio.channels.ClosedChannelException`. This happens when the second `uploadStream` is called after the channel has been closed. This can be reproduced by adding `Thread.sleep(1000)` below this line: https://github.com/apache/spark/blob/03306a6df39c9fd6cb581401c13c4dfc6bbd632e/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java#L217 This PR fixes the above issue and also improves the test failure messages of `assertErrorAndClosed`. ## How was this patch tested? Jenkins Closes #23041 from zsxwing/SPARK-26069. 
Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../spark/network/RpcIntegrationSuite.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 1f4d75c7e2ec5..45f4a1808562d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -371,7 +371,10 @@ private void assertErrorsContain(Set errors, Set contains) { private void assertErrorAndClosed(RpcResult result, String expectedError) { assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); - // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + // we expect 1 additional error, which should contain one of the follow messages: + // - "closed" + // - "Connection reset" + // - "java.nio.channels.ClosedChannelException" Set errors = result.errorMessages; assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + errors, 2, errors.size()); @@ -379,15 +382,18 @@ private void assertErrorAndClosed(RpcResult result, String expectedError) { Set containsAndClosed = Sets.newHashSet(expectedError); containsAndClosed.add("closed"); containsAndClosed.add("Connection reset"); + containsAndClosed.add("java.nio.channels.ClosedChannelException"); Pair, Set> r = checkErrorsContain(errors, containsAndClosed); - Set errorsNotFound = r.getRight(); - assertEquals(1, errorsNotFound.size()); - String err = errorsNotFound.iterator().next(); - assertTrue(err.equals("closed") || err.equals("Connection reset")); + assertTrue("Got a non-empty set " + r.getLeft(), r.getLeft().isEmpty()); - assertTrue(r.getLeft().isEmpty()); + Set errorsNotFound = r.getRight(); + assertEquals( + "The size of " + errorsNotFound.toString() + " was not 2", 2, errorsNotFound.size()); + for (String err: errorsNotFound) { + assertTrue("Found a wrong error " + err, containsAndClosed.contains(err)); + } } private Pair, Set> checkErrorsContain( From 058c4602b000b24deb764a810ef8b43c41fe63ae Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 16 Nov 2018 15:43:27 -0800 Subject: [PATCH 061/145] [SPARK-26092][SS] Use CheckpointFileManager to write the streaming metadata file ## What changes were proposed in this pull request? Use CheckpointFileManager to write the streaming `metadata` file so that the `metadata` file will never be a partial file. ## How was this patch tested? Jenkins Closes #23060 from zsxwing/SPARK-26092. 
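The point of the atomic write path used here is that readers never observe a half-written `metadata` file. As a concept-only sketch (plain local filesystem; this is not Spark's `CheckpointFileManager`, which also has to handle HDFS-like stores and cancellation), the write-to-temporary-then-rename idea looks roughly like this in Python:

```python
import json
import os
import tempfile


def write_metadata_atomically(path, metadata):
    """Write `metadata` (a dict) to `path` so readers never see a partial file.

    The data goes to a temporary file in the same directory and is then
    renamed over the target; on POSIX filesystems the rename is atomic.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".metadata.tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(metadata, f)
            f.flush()
            os.fsync(f.fileno())
        os.rename(tmp_path, path)  # atomically replaces the target
    except Exception:
        # On failure, "cancel" the write by removing the temporary file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Example usage (file name chosen only for illustration)
write_metadata_atomically("metadata", {"id": "query-1"})
```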
Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../streaming/CheckpointFileManager.scala | 2 +- .../execution/streaming/StreamExecution.scala | 1 + .../execution/streaming/StreamMetadata.scala | 23 +++++++++++++------ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala index 606ba250ad9d2..b3e4240c315bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -56,7 +56,7 @@ trait CheckpointFileManager { * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to * overwrite the file if it already exists. It should not throw * any exception if the file exists. However, if false, then the - * implementation must not overwrite if the file alraedy exists and + * implementation must not overwrite if the file already exists and * must throw `FileAlreadyExistsException` in that case. */ def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 631a6eb649ffb..89b4f40c9c0b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -88,6 +88,7 @@ abstract class StreamExecution( val resolvedCheckpointRoot = { val checkpointPath = new Path(checkpointRoot) val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + fs.mkdirs(checkpointPath) checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala index 0bc54eac4ee8e..516afbea5d9de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala @@ -19,16 +19,18 @@ package org.apache.spark.sql.execution.streaming import java.io.{InputStreamReader, OutputStreamWriter} import java.nio.charset.StandardCharsets +import java.util.ConcurrentModificationException import scala.util.control.NonFatal import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path} +import org.apache.hadoop.fs.{FileAlreadyExistsException, FSDataInputStream, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.streaming.StreamingQuery /** @@ -70,19 +72,26 @@ object StreamMetadata extends Logging { metadata: StreamMetadata, metadataFile: Path, hadoopConf: Configuration): Unit = { - var output: FSDataOutputStream = null + var output: CancellableFSDataOutputStream = null try { - val fs = metadataFile.getFileSystem(hadoopConf) - output = fs.create(metadataFile) + val fileManager = 
CheckpointFileManager.create(metadataFile.getParent, hadoopConf) + output = fileManager.createAtomic(metadataFile, overwriteIfPossible = false) val writer = new OutputStreamWriter(output) Serialization.write(metadata, writer) writer.close() } catch { - case NonFatal(e) => + case e: FileAlreadyExistsException => + if (output != null) { + output.cancel() + } + throw new ConcurrentModificationException( + s"Multiple streaming queries are concurrently using $metadataFile", e) + case e: Throwable => + if (output != null) { + output.cancel() + } logError(s"Error writing stream metadata $metadata to $metadataFile", e) throw e - } finally { - IOUtils.closeQuietly(output) } } } From d2792046a1b10a07b65fc30be573983f1237e450 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 16 Nov 2018 15:57:38 -0800 Subject: [PATCH 062/145] [SPARK-26095][BUILD] Disable parallelization in make-distibution.sh. It makes the build slower, but at least it doesn't hang. Seems that maven-shade-plugin has some issue with parallelization. Closes #23061 from vanzin/SPARK-26095. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- dev/make-distribution.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84f4ae9a64ff8..a550af93feecd 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -165,7 +165,7 @@ export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:ReservedCodeCacheSize=512m}" # Store the command as an array because $MVN variable might have spaces in it. # Normal quoting tricks don't work. # See: http://mywiki.wooledge.org/BashFAQ/050 -BUILD_COMMAND=("$MVN" -T 1C clean package -DskipTests $@) +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." From 23cd0e6e9e20a224a71859c158437e0a31982259 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 17 Nov 2018 15:07:20 +0800 Subject: [PATCH 063/145] [SPARK-26079][SQL] Ensure listener event delivery in StreamingQueryListenersConfSuite. Events are dispatched on a separate thread, so need to wait for them to be actually delivered before checking that the listener got them. Closes #23050 from vanzin/SPARK-26079. 
Authored-by: Marcelo Vanzin Signed-off-by: hyukjinkwon --- .../spark/sql/streaming/StreamingQueryListenersConfSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala index 1aaf8a9aa2d55..ddbc175e7ea48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -30,7 +30,6 @@ class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { import testImplicits._ - override protected def sparkConf: SparkConf = super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", "org.apache.spark.sql.streaming.TestListener") @@ -41,6 +40,8 @@ class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { StopStream ) + spark.sparkContext.listenerBus.waitUntilEmpty(5000) + assert(TestListener.queryStartedEvent != null) assert(TestListener.queryTerminatedEvent != null) } From b538c442cb3982cc4c3aac812a7d4764209dfbb7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Nov 2018 18:18:41 +0800 Subject: [PATCH 064/145] [MINOR][SQL] Fix typo in CTAS plan database string ## What changes were proposed in this pull request? Since [Spark 1.6.0](https://github.com/apache/spark/commit/56d7da14ab8f89bf4f303b27f51fd22d23967ffb#diff-6f38a103058a6e233b7ad80718452387R96), there was a redundant '}' character in CTAS string plan's database argument string; `default}`. This PR aims to fix it. **BEFORE** ```scala scala> sc.version res1: String = 1.6.0 scala> sql("create table t as select 1").explain == Physical Plan == ExecutedCommand CreateTableAsSelect [Database:default}, TableName: t, InsertIntoHiveTable] +- Project [1 AS _c0#3] +- OneRowRelation$ ``` **AFTER** ```scala scala> sql("create table t as select 1").explain == Physical Plan == Execute CreateHiveTableAsSelectCommand CreateHiveTableAsSelectCommand [Database:default, TableName: t, InsertIntoHiveTable] +- *(1) Project [1 AS 1#4] +- Scan OneRowRelation[] ``` ## How was this patch tested? Manual. Closes #23064 from dongjoon-hyun/SPARK-FIX. Authored-by: Dongjoon Hyun Signed-off-by: hyukjinkwon --- .../sql/hive/execution/CreateHiveTableAsSelectCommand.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index aa573b54a2b62..630bea5161f19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -96,7 +96,7 @@ case class CreateHiveTableAsSelectCommand( } override def argString: String = { - s"[Database:${tableDesc.database}}, " + + s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" } From ed46ac9f4736d23c2f7294133d4def93dc99cce1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Nov 2018 03:28:43 -0800 Subject: [PATCH 065/145] [SPARK-26091][SQL] Upgrade to 2.3.4 for Hive Metastore Client 2.3 ## What changes were proposed in this pull request? [Hive 2.3.4 is released on Nov. 
7th](https://hive.apache.org/downloads.html#7-november-2018-release-234-available). This PR aims to support that version. ## How was this patch tested? Pass the Jenkins with the updated version Closes #23059 from dongjoon-hyun/SPARK-26091. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- docs/sql-data-sources-hive-tables.md | 2 +- docs/sql-migration-guide-hive-compatibility.md | 2 +- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../org/apache/spark/sql/hive/client/IsolatedClientLoader.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/client/package.scala | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-data-sources-hive-tables.md b/docs/sql-data-sources-hive-tables.md index 687e6f8e0a7cc..28e1a39626666 100644 --- a/docs/sql-data-sources-hive-tables.md +++ b/docs/sql-data-sources-hive-tables.md @@ -115,7 +115,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.3. + options are 0.12.0 through 2.3.4. diff --git a/docs/sql-migration-guide-hive-compatibility.md b/docs/sql-migration-guide-hive-compatibility.md index 94849418030ef..dd7b06225714f 100644 --- a/docs/sql-migration-guide-hive-compatibility.md +++ b/docs/sql-migration-guide-hive-compatibility.md @@ -10,7 +10,7 @@ displayTitle: Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.4. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 74f21532b22df..66067704195dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.3.") + s"0.12.0 through 2.3.4.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c1d8fe53a9e8c..f56ca8cb08553 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -99,7 +99,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" | "2.3.4" => hive.v2_3 case version => throw new UnsupportedOperationException(s"Unsupported Hive Metastore version ($version). 
" + s"Please set ${HiveUtils.HIVE_METASTORE_VERSION.key} with a valid version.") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 25e9886fa6576..e4cf7299d2af6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.3", + case object v2_3 extends HiveVersion("2.3.4", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) From e557c53c59a98f601b15850bb89fd4e252135556 Mon Sep 17 00:00:00 2001 From: Shahid Date: Sat, 17 Nov 2018 09:43:33 -0600 Subject: [PATCH 066/145] [SPARK-26006][MLLIB] unpersist 'dataInternalRepr' in the PrefixSpan ## What changes were proposed in this pull request? Mllib's Prefixspan - run method - cached RDD stays in cache. After run is comlpeted , rdd remain in cache. We need to unpersist the cached RDD after run method. ## How was this patch tested? Existing tests Closes #23016 from shahidki31/SPARK-26006. Authored-by: Shahid Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 7aed2f3bd8a61..64d6a0bc47b97 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -174,6 +174,13 @@ class PrefixSpan private ( val freqSequences = results.map { case (seq: Array[Int], count: Long) => new FreqSequence(toPublicRepr(seq), count) } + // Cache the final RDD to the same storage level as input + if (data.getStorageLevel != StorageLevel.NONE) { + freqSequences.persist(data.getStorageLevel) + freqSequences.count() + } + dataInternalRepr.unpersist(false) + new PrefixSpanModel(freqSequences) } From e00cac989821aea238c7bf20b69068ef7cf2eef3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 17 Nov 2018 09:46:45 -0600 Subject: [PATCH 067/145] [SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading ## What changes were proposed in this pull request? Our `GBTClassifier` supports only `variance` impurity. But unfortunately, its `impurity` param by default contains the value `gini`: it is not even modifiable by the user and it differs from the actual impurity used, which is `variance`. This issue does not limit to a wrong value returned for it if the user queries by `getImpurity`, but it also affect the load of a saved model, as its `impurityStats` are created as `gini` (since this is the value stored for the model impurity) which leads to wrong `featureImportances` in model loaded from saved ones. The PR changes the `impurity` param used to one which allows only the value `variance`. ## How was this patch tested? modified UT Closes #22986 from mgaido91/SPARK-25959. 
Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../ml/classification/GBTClassifier.scala | 4 +++- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/tree/treeParams.scala | 19 ++++++++++--------- .../classification/GBTClassifierSuite.scala | 1 + project/MimaExcludes.scala | 11 +++++++++++ 6 files changed, 27 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62cfa39746ff0..62c6bdbdeb285 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - metadata.getAndSetParams(model) + // We ignore the impurity while loading models because in previous models it was wrongly + // set to gini (see SPARK-25959). + metadata.getAndSetParams(model, Some(List("impurity"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6fa656275c1fd..c9de85de42fa5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("1.4.0") object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities @Since("2.0.0") override def load(path: String): DecisionTreeRegressor = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82bf66ff66d8a..66d57ad6c4348 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 00157fe63af41..f1e3836ebe476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams { private[ml] trait DecisionTreeClassifierParams extends 
DecisionTreeParams with TreeClassifierParams -/** - * Parameters for Decision Tree-based regression algorithms. - */ -private[ml] trait TreeRegressorParams extends Params { - +private[ml] trait HasVarianceImpurity extends Params { /** * Criterion used for information gain calculation (case-insensitive). * Supported: "variance". @@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}", (value: String) => - TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) + HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params { } } -private[ml] object TreeRegressorParams { +private[ml] object HasVarianceImpurity { // These options should be lowercase. final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase(Locale.ROOT)) } +/** + * Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends HasVarianceImpurity + private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams with TreeRegressorParams with HasVarianceCol { @@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams { Array("logistic").map(_.toLowerCase(Locale.ROOT)) } -private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { +private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity { /** * Loss function which GBT tries to minimize. 
(case-insensitive) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 304977634189c..cedbaf1858ef4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { model2: GBTClassificationModel): Unit = { TreeTests.checkEqual(model, model2) assert(model.numFeatures === model2.numFeatures) + assert(model.featureImportances == model2.featureImportances) } val gbt = new GBTClassifier() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b030b6ca2922f..a8d2b5d1d9cb6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,17 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"), From 034ae305c33b1990b3c1a284044002874c343b4d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sun, 18 Nov 2018 16:02:15 +0800 Subject: [PATCH 068/145] [SPARK-26033][PYTHON][TESTS] Break large ml/tests.py file into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR breaks down the large ml/tests.py file that contains all Python ML unit tests into several smaller test files to be easier to read and maintain. The tests are broken down as follows: ``` pyspark ├── __init__.py ... ├── ml │ ├── __init__.py ... 
│ ├── tests │ │ ├── __init__.py │ │ ├── test_algorithms.py │ │ ├── test_base.py │ │ ├── test_evaluation.py │ │ ├── test_feature.py │ │ ├── test_image.py │ │ ├── test_linalg.py │ │ ├── test_param.py │ │ ├── test_persistence.py │ │ ├── test_pipeline.py │ │ ├── test_stat.py │ │ ├── test_training_summary.py │ │ ├── test_tuning.py │ │ └── test_wrapper.py ... ├── testing ... │ ├── mlutils.py ... ``` ## How was this patch tested? Ran tests manually by module to ensure test count was the same, and ran `python/run-tests --modules=pyspark-ml` to verify all passing with Python 2.7 and Python 3.6. Closes #23063 from BryanCutler/python-test-breakup-ml-SPARK-26033. Authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 16 +- python/pyspark/ml/tests.py | 2762 ----------------- python/pyspark/ml/tests/__init__.py | 16 + python/pyspark/ml/tests/test_algorithms.py | 349 +++ python/pyspark/ml/tests/test_base.py | 85 + python/pyspark/ml/tests/test_evaluation.py | 71 + python/pyspark/ml/tests/test_feature.py | 318 ++ python/pyspark/ml/tests/test_image.py | 118 + python/pyspark/ml/tests/test_linalg.py | 392 +++ python/pyspark/ml/tests/test_param.py | 372 +++ python/pyspark/ml/tests/test_persistence.py | 369 +++ python/pyspark/ml/tests/test_pipeline.py | 77 + python/pyspark/ml/tests/test_stat.py | 58 + .../pyspark/ml/tests/test_training_summary.py | 258 ++ python/pyspark/ml/tests/test_tuning.py | 552 ++++ python/pyspark/ml/tests/test_wrapper.py | 120 + python/pyspark/testing/mlutils.py | 161 + 17 files changed, 3331 insertions(+), 2763 deletions(-) delete mode 100755 python/pyspark/ml/tests.py create mode 100644 python/pyspark/ml/tests/__init__.py create mode 100644 python/pyspark/ml/tests/test_algorithms.py create mode 100644 python/pyspark/ml/tests/test_base.py create mode 100644 python/pyspark/ml/tests/test_evaluation.py create mode 100644 python/pyspark/ml/tests/test_feature.py create mode 100644 python/pyspark/ml/tests/test_image.py create mode 100644 python/pyspark/ml/tests/test_linalg.py create mode 100644 python/pyspark/ml/tests/test_param.py create mode 100644 python/pyspark/ml/tests/test_persistence.py create mode 100644 python/pyspark/ml/tests/test_pipeline.py create mode 100644 python/pyspark/ml/tests/test_stat.py create mode 100644 python/pyspark/ml/tests/test_training_summary.py create mode 100644 python/pyspark/ml/tests/test_tuning.py create mode 100644 python/pyspark/ml/tests/test_wrapper.py create mode 100644 python/pyspark/testing/mlutils.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 547635a412913..eef7f259391b8 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -452,6 +452,7 @@ def __hash__(self): "python/pyspark/ml/" ], python_test_goals=[ + # doctests "pyspark.ml.classification", "pyspark.ml.clustering", "pyspark.ml.evaluation", @@ -463,7 +464,20 @@ def __hash__(self): "pyspark.ml.regression", "pyspark.ml.stat", "pyspark.ml.tuning", - "pyspark.ml.tests", + # unittests + "pyspark.ml.tests.test_algorithms", + "pyspark.ml.tests.test_base", + "pyspark.ml.tests.test_evaluation", + "pyspark.ml.tests.test_feature", + "pyspark.ml.tests.test_image", + "pyspark.ml.tests.test_linalg", + "pyspark.ml.tests.test_param", + "pyspark.ml.tests.test_persistence", + "pyspark.ml.tests.test_pipeline", + "pyspark.ml.tests.test_stat", + "pyspark.ml.tests.test_training_summary", + "pyspark.ml.tests.test_tuning", + "pyspark.ml.tests.test_wrapper", ], blacklisted_python_implementations=[ "PyPy" # Skip these 
tests under PyPy since they require numpy and it isn't available there diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py deleted file mode 100755 index 2b4b7315d98c0..0000000000000 --- a/python/pyspark/ml/tests.py +++ /dev/null @@ -1,2762 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for MLlib Python DataFrame-based APIs. -""" -import sys -if sys.version > '3': - xrange = range - basestring = str - -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from shutil import rmtree -import tempfile -import array as pyarray -import numpy as np -from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros -import inspect -import py4j - -from pyspark import keyword_only, SparkContext -from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer -from pyspark.ml.classification import * -from pyspark.ml.clustering import * -from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ - MulticlassClassificationEvaluator, RegressionEvaluator -from pyspark.ml.feature import * -from pyspark.ml.fpm import FPGrowth, FPGrowthModel -from pyspark.ml.image import ImageSchema -from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ - SparseMatrix, SparseVector, Vector, VectorUDT, Vectors -from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed -from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ - LinearRegression -from pyspark.ml.stat import ChiSquareTest -from pyspark.ml.tuning import * -from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.serializers import PickleSerializer -from pyspark.sql import DataFrame, Row, SparkSession, HiveContext -from pyspark.sql.functions import rand -from pyspark.sql.types import DoubleType, IntegerType -from pyspark.storagelevel import * -from pyspark.testing.utils import QuietTest, ReusedPySparkTestCase as PySparkTestCase - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class SparkSessionTestCase(PySparkTestCase): - @classmethod - def setUpClass(cls): - PySparkTestCase.setUpClass() - cls.spark = 
SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - PySparkTestCase.tearDownClass() - cls.spark.stop() - - -class MockDataset(DataFrame): - - def __init__(self): - self.index = 0 - - -class HasFake(Params): - - def __init__(self): - super(HasFake, self).__init__() - self.fake = Param(self, "fake", "fake param") - - def getFake(self): - return self.getOrDefault(self.fake) - - -class MockTransformer(Transformer, HasFake): - - def __init__(self): - super(MockTransformer, self).__init__() - self.dataset_index = None - - def _transform(self, dataset): - self.dataset_index = dataset.index - dataset.index += 1 - return dataset - - -class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): - - shift = Param(Params._dummy(), "shift", "The amount by which to shift " + - "data in a DataFrame", - typeConverter=TypeConverters.toFloat) - - def __init__(self, shiftVal=1): - super(MockUnaryTransformer, self).__init__() - self._setDefault(shift=1) - self._set(shift=shiftVal) - - def getShift(self): - return self.getOrDefault(self.shift) - - def setShift(self, shift): - self._set(shift=shift) - - def createTransformFunc(self): - shiftVal = self.getShift() - return lambda x: x + shiftVal - - def outputDataType(self): - return DoubleType() - - def validateInputType(self, inputType): - if inputType != DoubleType(): - raise TypeError("Bad input type: {}. ".format(inputType) + - "Requires Double.") - - -class MockEstimator(Estimator, HasFake): - - def __init__(self): - super(MockEstimator, self).__init__() - self.dataset_index = None - - def _fit(self, dataset): - self.dataset_index = dataset.index - model = MockModel() - self._copyValues(model) - return model - - -class MockModel(MockTransformer, Model, HasFake): - pass - - -class JavaWrapperMemoryTests(SparkSessionTestCase): - - def test_java_object_gets_detached(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - - model = lr.fit(df) - summary = model.summary - - self.assertIsInstance(model, JavaWrapper) - self.assertIsInstance(summary, JavaWrapper) - self.assertIsInstance(model, JavaParams) - self.assertNotIsInstance(summary, JavaParams) - - error_no_object = 'Target Object ID does not exist for this gateway' - - self.assertIn("LinearRegression_", model._java_obj.toString()) - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - model.__del__() - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - try: - summary.__del__() - except: - pass - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - summary._java_obj.toString() - - -class ParamTypeConversionTests(PySparkTestCase): - """ - Test that param type conversion happens. 
- """ - - def test_int(self): - lr = LogisticRegression(maxIter=5.0) - self.assertEqual(lr.getMaxIter(), 5) - self.assertTrue(type(lr.getMaxIter()) == int) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) - - def test_float(self): - lr = LogisticRegression(tol=1) - self.assertEqual(lr.getTol(), 1.0) - self.assertTrue(type(lr.getTol()) == float) - self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) - - def test_vector(self): - ewp = ElementwiseProduct(scalingVec=[1, 3]) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) - ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) - self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) - - def test_list(self): - l = [0, 1] - for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), - range(len(l)), l), pyarray.array('l', l), xrange(2), tuple(l)]: - converted = TypeConverters.toList(lst_like) - self.assertEqual(type(converted), list) - self.assertListEqual(converted, l) - - def test_list_int(self): - for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), - SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), - pyarray.array('d', [1.0, 2.0])]: - vs = VectorSlicer(indices=indices) - self.assertListEqual(vs.getIndices(), [1, 2]) - self.assertTrue(all([type(v) == int for v in vs.getIndices()])) - self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) - - def test_list_float(self): - b = Bucketizer(splits=[1, 4]) - self.assertEqual(b.getSplits(), [1.0, 4.0]) - self.assertTrue(all([type(v) == float for v in b.getSplits()])) - self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) - - def test_list_string(self): - for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: - idx_to_string = IndexToString(labels=labels) - self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) - self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) - - def test_string(self): - lr = LogisticRegression() - for col in ['features', u'features', np.str_('features')]: - lr.setFeaturesCol(col) - self.assertEqual(lr.getFeaturesCol(), 'features') - self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) - - def test_bool(self): - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) - - -class PipelineTests(PySparkTestCase): - - def test_pipeline(self): - dataset = MockDataset() - estimator0 = MockEstimator() - transformer1 = MockTransformer() - estimator2 = MockEstimator() - transformer3 = MockTransformer() - pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) - pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - model0, transformer1, model2, transformer3 = pipeline_model.stages - self.assertEqual(0, model0.dataset_index) - self.assertEqual(0, model0.getFake()) - self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.getFake()) - self.assertEqual(2, dataset.index) - self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") - self.assertIsNone(transformer3.dataset_index, - "The last transformer shouldn't be called in fit.") - dataset = pipeline_model.transform(dataset) - self.assertEqual(2, model0.dataset_index) - 
self.assertEqual(3, transformer1.dataset_index) - self.assertEqual(4, model2.dataset_index) - self.assertEqual(5, transformer3.dataset_index) - self.assertEqual(6, dataset.index) - - def test_identity_pipeline(self): - dataset = MockDataset() - - def doTransform(pipeline): - pipeline_model = pipeline.fit(dataset) - return pipeline_model.transform(dataset) - # check that empty pipeline did not perform any transformation - self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) - # check that failure to set stages param will raise KeyError for missing param - self.assertRaises(KeyError, lambda: doTransform(Pipeline())) - - -class TestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(TestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(OtherTestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class HasThrowableProperty(Params): - - def __init__(self): - super(HasThrowableProperty, self).__init__() - self.p = Param(self, "none", "empty param") - - @property - def test_property(self): - raise RuntimeError("Test property to raise error when invoked") - - -class ParamTests(SparkSessionTestCase): - - def test_copy_new_parent(self): - testParams = TestParams() - # Copying an instantiated param should fail - with self.assertRaises(ValueError): - testParams.maxIter._copy_new_parent(testParams) - # Copying a dummy param should succeed - TestParams.maxIter._copy_new_parent(testParams) - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_param(self): - testParams = TestParams() - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_hasparam(self): - testParams = TestParams() - self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) - self.assertFalse(testParams.hasParam("notAParameter")) - self.assertTrue(testParams.hasParam(u"maxIter")) - - def test_resolveparam(self): - testParams = TestParams() - self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) - self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) - - self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) - if sys.version_info[0] >= 3: - # In Python 3, it is allowed to get/set attributes with non-ascii characters. 
- e_cls = AttributeError - else: - e_cls = UnicodeEncodeError - self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) - - def test_params(self): - testParams = TestParams() - maxIter = testParams.maxIter - inputCol = testParams.inputCol - seed = testParams.seed - - params = testParams.params - self.assertEqual(params, [inputCol, maxIter, seed]) - - self.assertTrue(testParams.hasParam(maxIter.name)) - self.assertTrue(testParams.hasDefault(maxIter)) - self.assertFalse(testParams.isSet(maxIter)) - self.assertTrue(testParams.isDefined(maxIter)) - self.assertEqual(testParams.getMaxIter(), 10) - testParams.setMaxIter(100) - self.assertTrue(testParams.isSet(maxIter)) - self.assertEqual(testParams.getMaxIter(), 100) - - self.assertTrue(testParams.hasParam(inputCol.name)) - self.assertFalse(testParams.hasDefault(inputCol)) - self.assertFalse(testParams.isSet(inputCol)) - self.assertFalse(testParams.isDefined(inputCol)) - with self.assertRaises(KeyError): - testParams.getInputCol() - - otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " + - "set raises an error for a non-member parameter.", - typeConverter=TypeConverters.toString) - with self.assertRaises(ValueError): - testParams.set(otherParam, "value") - - # Since the default is normally random, set it to a known number for debug str - testParams._setDefault(seed=41) - testParams.setSeed(43) - - self.assertEqual( - testParams.explainParams(), - "\n".join(["inputCol: input column name. (undefined)", - "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", - "seed: random seed. (default: 41, current: 43)"])) - - def test_kmeans_param(self): - algo = KMeans() - self.assertEqual(algo.getInitMode(), "k-means||") - algo.setK(10) - self.assertEqual(algo.getK(), 10) - algo.setInitSteps(10) - self.assertEqual(algo.getInitSteps(), 10) - self.assertEqual(algo.getDistanceMeasure(), "euclidean") - algo.setDistanceMeasure("cosine") - self.assertEqual(algo.getDistanceMeasure(), "cosine") - - def test_hasseed(self): - noSeedSpecd = TestParams() - withSeedSpecd = TestParams(seed=42) - other = OtherTestParams() - # Check that we no longer use 42 as the magic number - self.assertNotEqual(noSeedSpecd.getSeed(), 42) - origSeed = noSeedSpecd.getSeed() - # Check that we only compute the seed once - self.assertEqual(noSeedSpecd.getSeed(), origSeed) - # Check that a specified seed is honored - self.assertEqual(withSeedSpecd.getSeed(), 42) - # Check that a different class has a different seed - self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) - - def test_param_property_error(self): - param_store = HasThrowableProperty() - self.assertRaises(RuntimeError, lambda: param_store.test_property) - params = param_store.params # should not invoke the property 'test_property' - self.assertEqual(len(params), 1) - - def test_word2vec_param(self): - model = Word2Vec().setWindowSize(6) - # Check windowSize is set properly - self.assertEqual(model.getWindowSize(), 6) - - def test_copy_param_extras(self): - tp = TestParams(seed=42) - extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} - tp_copy = tp.copy(extra=extra) - self.assertEqual(tp.uid, tp_copy.uid) - self.assertEqual(tp.params, tp_copy.params) - for k, v in extra.items(): - self.assertTrue(tp_copy.isDefined(k)) - self.assertEqual(tp_copy.getOrDefault(k), v) - copied_no_extra = {} - for k, v in tp_copy._paramMap.items(): - if k not in extra: - copied_no_extra[k] = v - self.assertEqual(tp._paramMap, copied_no_extra) - 
self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) - - def test_logistic_regression_check_thresholds(self): - self.assertIsInstance( - LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), - LogisticRegression - ) - - self.assertRaisesRegexp( - ValueError, - "Logistic Regression getThreshold found inconsistent.*$", - LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] - ) - - def test_preserve_set_state(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - self.assertFalse(binarizer.isSet("threshold")) - binarizer.transform(dataset) - binarizer._transfer_params_from_java() - self.assertFalse(binarizer.isSet("threshold"), - "Params not explicitly set should remain unset after transform") - - def test_default_params_transferred(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - # intentionally change the pyspark default, but don't set it - binarizer._defaultParamMap[binarizer.outputCol] = "my_default" - result = binarizer.transform(dataset).select("my_default").collect() - self.assertFalse(binarizer.isSet(binarizer.outputCol)) - self.assertEqual(result[0][0], 1.0) - - @staticmethod - def check_params(test_self, py_stage, check_params_exist=True): - """ - Checks common requirements for Params.params: - - set of params exist in Java and Python and are ordered by names - - param parent has the same UID as the object's UID - - default param value from Java matches value in Python - - optionally check if all params from Java also exist in Python - """ - py_stage_str = "%s %s" % (type(py_stage), py_stage) - if not hasattr(py_stage, "_to_java"): - return - java_stage = py_stage._to_java() - if java_stage is None: - return - test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) - if check_params_exist: - param_names = [p.name for p in py_stage.params] - java_params = list(java_stage.params()) - java_param_names = [jp.name() for jp in java_params] - test_self.assertEqual( - param_names, sorted(java_param_names), - "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" - % (py_stage_str, java_param_names, param_names)) - for p in py_stage.params: - test_self.assertEqual(p.parent, py_stage.uid) - java_param = java_stage.getParam(p.name) - py_has_default = py_stage.hasDefault(p) - java_has_default = java_stage.hasDefault(java_param) - test_self.assertEqual(py_has_default, java_has_default, - "Default value mismatch of param %s for Params %s" - % (p.name, str(py_stage))) - if py_has_default: - if p.name == "seed": - continue # Random seeds between Spark and PySpark are different - java_default = _java2py(test_self.sc, - java_stage.clear(java_param).getOrDefault(java_param)) - py_stage._clear(p) - py_default = py_stage.getOrDefault(p) - # equality test for NaN is always False - if isinstance(java_default, float) and np.isnan(java_default): - java_default = "NaN" - py_default = "NaN" if np.isnan(py_default) else "not NaN" - test_self.assertEqual( - java_default, py_default, - "Java default %s != python default %s of param %s for Params %s" - % (str(java_default), str(py_default), p.name, str(py_stage))) - - -class EvaluatorTests(SparkSessionTestCase): - - def test_java_params(self): - """ - This tests a bug fixed by SPARK-18274 which causes multiple copies - of a Params instance in Python to be linked to the same Java instance. 
- """ - evaluator = RegressionEvaluator(metricName="r2") - df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) - evaluator.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"}) - evaluator.evaluate(df) - evaluatorCopy.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") - - def test_clustering_evaluator_with_cosine_distance(self): - featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), - [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), - ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) - dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) - evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") - self.assertEqual(evaluator.getDistanceMeasure(), "cosine") - self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) - - -class FeatureTests(SparkSessionTestCase): - - def test_binarizer(self): - b0 = Binarizer() - self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) - self.assertTrue(all([~b0.isSet(p) for p in b0.params])) - self.assertTrue(b0.hasDefault(b0.threshold)) - self.assertEqual(b0.getThreshold(), 0.0) - b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) - self.assertTrue(all([b0.isSet(p) for p in b0.params])) - self.assertEqual(b0.getThreshold(), 1.0) - self.assertEqual(b0.getInputCol(), "input") - self.assertEqual(b0.getOutputCol(), "output") - - b0c = b0.copy({b0.threshold: 2.0}) - self.assertEqual(b0c.uid, b0.uid) - self.assertListEqual(b0c.params, b0.params) - self.assertEqual(b0c.getThreshold(), 2.0) - - b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") - self.assertNotEqual(b1.uid, b0.uid) - self.assertEqual(b1.getThreshold(), 2.0) - self.assertEqual(b1.getInputCol(), "input") - self.assertEqual(b1.getOutputCol(), "output") - - def test_idf(self): - dataset = self.spark.createDataFrame([ - (DenseVector([1.0, 2.0]),), - (DenseVector([0.0, 1.0]),), - (DenseVector([3.0, 0.2]),)], ["tf"]) - idf0 = IDF(inputCol="tf") - self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) - idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) - self.assertEqual(idf0m.uid, idf0.uid, - "Model should inherit the UID from its parent estimator.") - output = idf0m.transform(dataset) - self.assertIsNotNone(output.head().idf) - # Test that parameters transferred to Python Model - ParamTests.check_params(self, idf0m) - - def test_ngram(self): - dataset = self.spark.createDataFrame([ - Row(input=["a", "b", "c", "d", "e"])]) - ngram0 = NGram(n=4, inputCol="input", outputCol="output") - self.assertEqual(ngram0.getN(), 4) - self.assertEqual(ngram0.getInputCol(), "input") - self.assertEqual(ngram0.getOutputCol(), "output") - transformedDF = ngram0.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) - - def test_stopwordsremover(self): - dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) - stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") - # Default - self.assertEqual(stopWordRemover.getInputCol(), "input") - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["panda"]) - self.assertEqual(type(stopWordRemover.getStopWords()), list) - self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], 
basestring)) - # Custom - stopwords = ["panda"] - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getInputCol(), "input") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a"]) - # with language selection - stopwords = StopWordsRemover.loadDefaultStopWords("turkish") - dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - # with locale - stopwords = ["BELKİ"] - dataset = self.spark.createDataFrame([Row(input=["belki"])]) - stopWordRemover.setStopWords(stopwords).setLocale("tr") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - - def test_count_vectorizer_with_binary(self): - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) - cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") - model = cv.fit(dataset) - - transformedList = model.transform(dataset).select("features", "expected").collect() - - for r in transformedList: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_with_maxDF(self): - dataset = self.spark.createDataFrame([ - (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), - (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - cv = CountVectorizer(inputCol="words", outputCol="features") - model1 = cv.setMaxDF(3).fit(dataset) - self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) - - transformedList1 = model1.transform(dataset).select("features", "expected").collect() - - for r in transformedList1: - feature, expected = r - self.assertEqual(feature, expected) - - model2 = cv.setMaxDF(0.75).fit(dataset) - self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) - - transformedList2 = model2.transform(dataset).select("features", "expected").collect() - - for r in transformedList2: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_from_vocab(self): - model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", - outputCol="features", minTF=2) - self.assertEqual(model.vocabulary, ["a", "b", "c"]) - self.assertEqual(model.getMinTF(), 2) - - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), - (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - - transformed_list = model.transform(dataset).select("features", "expected").collect() - - for r in transformed_list: - feature, expected = r - self.assertEqual(feature, expected) - - # Test an empty vocabulary - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): - CountVectorizerModel.from_vocabulary([], inputCol="words") - - # Test model with default settings can 
transform - model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") - transformed_list = model_default.transform(dataset)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 3) - - def test_rformula_force_index_label(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - # Does not index label by default since it's numeric type. - rf = RFormula(formula="y ~ x + s") - model = rf.fit(df) - transformedDF = model.transform(df) - self.assertEqual(transformedDF.head().label, 1.0) - # Force to index label. - rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) - model2 = rf2.fit(df) - transformedDF2 = model2.transform(df) - self.assertEqual(transformedDF2.head().label, 0.0) - - def test_rformula_string_indexer_order_type(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") - self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') - transformedDF = rf.fit(df).transform(df) - observed = transformedDF.select("features").collect() - expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] - for i in range(0, len(expected)): - self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) - - def test_string_indexer_handle_invalid(self): - df = self.spark.createDataFrame([ - (0, "a"), - (1, "d"), - (2, None)], ["id", "label"]) - - si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", - stringOrderType="alphabetAsc") - model1 = si1.fit(df) - td1 = model1.transform(df) - actual1 = td1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] - self.assertEqual(actual1, expected1) - - si2 = si1.setHandleInvalid("skip") - model2 = si2.fit(df) - td2 = model2.transform(df) - actual2 = td2.select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] - self.assertEqual(actual2, expected2) - - def test_string_indexer_from_labels(self): - model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", - outputCol="indexed", handleInvalid="keep") - self.assertEqual(model.labels, ["a", "b", "c"]) - - df1 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, None), - (3, "b"), - (4, "b")], ["id", "label"]) - - result1 = model.transform(df1) - actual1 = result1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), - Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] - self.assertEqual(actual1, expected1) - - model_empty_labels = StringIndexerModel.from_labels( - [], inputCol="label", outputCol="indexed", handleInvalid="keep") - actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), - Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] - self.assertEqual(actual2, expected2) - - # Test model with default settings can transform - model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") - df2 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, "b"), - (3, "b"), - (4, "b")], ["id", "label"]) - transformed_list = model_default.transform(df2)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 5) - - def 
test_vector_size_hint(self): - df = self.spark.createDataFrame( - [(0, Vectors.dense([0.0, 10.0, 0.5])), - (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), - (2, Vectors.dense([2.0, 12.0]))], - ["id", "vector"]) - - sizeHint = VectorSizeHint( - inputCol="vector", - handleInvalid="skip") - sizeHint.setSize(3) - self.assertEqual(sizeHint.getSize(), 3) - - output = sizeHint.transform(df).head().vector - expected = DenseVector([0.0, 10.0, 0.5]) - self.assertEqual(output, expected) - - -class HasInducedError(Params): - - def __init__(self): - super(HasInducedError, self).__init__() - self.inducedError = Param(self, "inducedError", - "Uniformly-distributed error added to feature") - - def getInducedError(self): - return self.getOrDefault(self.inducedError) - - -class InducedErrorModel(Model, HasInducedError): - - def __init__(self): - super(InducedErrorModel, self).__init__() - - def _transform(self, dataset): - return dataset.withColumn("prediction", - dataset.feature + (rand(0) * self.getInducedError())) - - -class InducedErrorEstimator(Estimator, HasInducedError): - - def __init__(self, inducedError=1.0): - super(InducedErrorEstimator, self).__init__() - self._set(inducedError=inducedError) - - def _fit(self, dataset): - model = InducedErrorModel() - self._copyValues(model) - return model - - -class CrossValidatorTests(SparkSessionTestCase): - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvCopied = cv.copy() - self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) - - cvModel = cv.fit(dataset) - cvModelCopied = cvModel.copy() - for index in range(len(cvModel.avgMetrics)): - self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) - < 0.0001) - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has 
R-squared of 1") - - def test_param_grid_type_coercion(self): - lr = LogisticRegression(maxIter=10) - paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() - for param in paramGrid: - for v in param.values(): - assert(type(v) == float) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for CrossValidator will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - lrModel = cvModel.bestModel - - cvModelPath = temp_path + "/cvModel" - lrModel.save(cvModelPath) - loadedLrModel = LogisticRegressionModel.load(cvModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cv.setParallelism(1) - cvSerialModel = cv.fit(dataset) - cv.setParallelism(2) - cvParallelModel = cv.fit(dataset) - self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - numFolds = 3 - cv = CrossValidator(estimator=lr, 
estimatorParamMaps=grid, evaluator=evaluator, - numFolds=numFolds, collectSubModels=True) - - def checkSubModels(subModels): - self.assertEqual(len(subModels), numFolds) - for i in range(numFolds): - self.assertEqual(len(subModels[i]), len(grid)) - - cvModel = cv.fit(dataset) - checkSubModels(cvModel.subModels) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testCrossValidatorSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - cvModel.save(savingPathWithSubModels) - cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) - checkSubModels(cvModel3.subModels) - cvModel4 = cvModel3.copy() - checkSubModels(cvModel4.subModels) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) - self.assertEqual(cvModel2.subModels, None) - - for i in range(numFolds): - for j in range(len(grid)): - self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) - - def test_save_load_nested_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - - originalParamMap = cv.getEstimatorParamMaps() - loadedParamMap = loadedCV.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - -class TrainValidationSplitTests(SparkSessionTestCase): - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same 
size of grid parameter") - self.assertEqual(0.0, min(validationMetrics)) - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same size of grid parameter") - self.assertEqual(1.0, max(validationMetrics)) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - lrModel = tvsModel.bestModel - - tvsModelPath = temp_path + "/tvsModel" - lrModel.save(tvsModelPath) - loadedLrModel = LogisticRegressionModel.load(tvsModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - # This tests saving and loading the trained model only. 
- # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvs.setParallelism(1) - tvsSerialModel = tvs.fit(dataset) - tvs.setParallelism(2) - tvsParallelModel = tvs.fit(dataset) - self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, - collectSubModels=True) - tvsModel = tvs.fit(dataset) - self.assertEqual(len(tvsModel.subModels), len(grid)) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testTrainValidationSplitSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - tvsModel.save(savingPathWithSubModels) - tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) - self.assertEqual(len(tvsModel3.subModels), len(grid)) - tvsModel4 = tvsModel3.copy() - self.assertEqual(len(tvsModel4.subModels), len(grid)) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) - self.assertEqual(tvsModel2.subModels, None) - - for i in range(len(grid)): - self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) - - def test_save_load_nested_estimator(self): - # This tests saving and loading the trained model only. 
- # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - - originalParamMap = tvs.getEstimatorParamMaps() - loadedParamMap = loadedTvs.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsCopied = tvs.copy() - tvsModelCopied = tvsModel.copy() - - self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, - "Copied TrainValidationSplit has the same uid of Estimator") - - self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) - self.assertEqual(len(tvsModel.validationMetrics), - len(tvsModelCopied.validationMetrics), - "Copied validationMetrics has the same size of the original") - for index in range(len(tvsModel.validationMetrics)): - self.assertEqual(tvsModel.validationMetrics[index], - tvsModelCopied.validationMetrics[index]) - - -class PersistenceTest(SparkSessionTestCase): - - def test_linear_regression(self): - lr = LinearRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/lr" - lr.save(lr_path) - lr2 = LinearRegression.load(lr_path) - self.assertEqual(lr.uid, lr2.uid) - self.assertEqual(type(lr.uid), type(lr2.uid)) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LinearRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_linear_regression_pmml_basic(self): - # Most of the validation is done in the Scala side, here we just check - # that we output text rather than parquet (e.g. that the format flag - # was respected). 
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1) - model = lr.fit(df) - path = tempfile.mkdtemp() - lr_path = path + "/lr-pmml" - model.write().format("pmml").save(lr_path) - pmml_text_list = self.sc.textFile(lr_path).collect() - pmml_text = "\n".join(pmml_text_list) - self.assertIn("Apache Spark", pmml_text) - self.assertIn("PMML", pmml_text) - - def test_logistic_regression(self): - lr = LogisticRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/logreg" - lr.save(lr_path) - lr2 = LogisticRegression.load(lr_path) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LogisticRegression instance uid (%s) " - "did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LogisticRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def _compare_params(self, m1, m2, param): - """ - Compare 2 ML Params instances for the given param, and assert both have the same param value - and parent. The param must be a parameter of m1. - """ - # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. - if m1.isDefined(param): - paramValue1 = m1.getOrDefault(param) - paramValue2 = m2.getOrDefault(m2.getParam(param.name)) - if isinstance(paramValue1, Params): - self._compare_pipelines(paramValue1, paramValue2) - else: - self.assertEqual(paramValue1, paramValue2) # for general types param - # Assert parents are equal - self.assertEqual(param.parent, m2.getParam(param.name).parent) - else: - # If m1 is not defined param, then m2 should not, too. See SPARK-14931. - self.assertFalse(m2.isDefined(m2.getParam(param.name))) - - def _compare_pipelines(self, m1, m2): - """ - Compare 2 ML types, asserting that they are equivalent. 
- This currently supports: - - basic types - - Pipeline, PipelineModel - - OneVsRest, OneVsRestModel - This checks: - - uid - - type - - Param values and parents - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams) or isinstance(m1, Transformer): - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - self._compare_params(m1, m2, p) - elif isinstance(m1, Pipeline): - self.assertEqual(len(m1.getStages()), len(m2.getStages())) - for s1, s2 in zip(m1.getStages(), m2.getStages()): - self._compare_pipelines(s1, s2) - elif isinstance(m1, PipelineModel): - self.assertEqual(len(m1.stages), len(m2.stages)) - for s1, s2 in zip(m1.stages, m2.stages): - self._compare_pipelines(s1, s2) - elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): - for p in m1.params: - self._compare_params(m1, m2, p) - if isinstance(m1, OneVsRestModel): - self.assertEqual(len(m1.models), len(m2.models)) - for x, y in zip(m1.models, m2.models): - self._compare_pipelines(x, y) - else: - raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) - - def test_pipeline_persistence(self): - """ - Pipeline[HashingTF, PCA] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - pl = Pipeline(stages=[tf, pca]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_nested_pipeline_persistence(self): - """ - Pipeline[HashingTF, Pipeline[PCA]] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - p0 = Pipeline(stages=[pca]) - pl = Pipeline(stages=[tf, p0]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_python_transformer_pipeline_persistence(self): - """ - Pipeline[MockUnaryTransformer, Binarizer] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.range(0, 10).toDF('input') - tf = MockUnaryTransformer(shiftVal=2)\ - .setInputCol("input").setOutputCol("shiftedInput") - tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") - pl = Pipeline(stages=[tf, tf2]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - 
try: - rmtree(temp_path) - except OSError: - pass - - def test_onevsrest(self): - temp_path = tempfile.mkdtemp() - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))] * 10, - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - model = ovr.fit(df) - ovrPath = temp_path + "/ovr" - ovr.save(ovrPath) - loadedOvr = OneVsRest.load(ovrPath) - self._compare_pipelines(ovr, loadedOvr) - modelPath = temp_path + "/ovrModel" - model.save(modelPath) - loadedModel = OneVsRestModel.load(modelPath) - self._compare_pipelines(model, loadedModel) - - def test_decisiontree_classifier(self): - dt = DecisionTreeClassifier(maxDepth=1) - path = tempfile.mkdtemp() - dtc_path = path + "/dtc" - dt.save(dtc_path) - dt2 = DecisionTreeClassifier.load(dtc_path) - self.assertEqual(dt2.uid, dt2.maxDepth.parent, - "Loaded DecisionTreeClassifier instance uid (%s) " - "did not match Param's uid (%s)" - % (dt2.uid, dt2.maxDepth.parent)) - self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], - "Loaded DecisionTreeClassifier instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_decisiontree_regressor(self): - dt = DecisionTreeRegressor(maxDepth=1) - path = tempfile.mkdtemp() - dtr_path = path + "/dtr" - dt.save(dtr_path) - dt2 = DecisionTreeClassifier.load(dtr_path) - self.assertEqual(dt2.uid, dt2.maxDepth.parent, - "Loaded DecisionTreeRegressor instance uid (%s) " - "did not match Param's uid (%s)" - % (dt2.uid, dt2.maxDepth.parent)) - self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], - "Loaded DecisionTreeRegressor instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_default_read_write(self): - temp_path = tempfile.mkdtemp() - - lr = LogisticRegression() - lr.setMaxIter(50) - lr.setThreshold(.75) - writer = DefaultParamsWriter(lr) - - savePath = temp_path + "/lr" - writer.save(savePath) - - reader = DefaultParamsReadable.read() - lr2 = reader.load(savePath) - - self.assertEqual(lr.uid, lr2.uid) - self.assertEqual(lr.extractParamMap(), lr2.extractParamMap()) - - # test overwrite - lr.setThreshold(.8) - writer.overwrite().save(savePath) - - reader = DefaultParamsReadable.read() - lr3 = reader.load(savePath) - - self.assertEqual(lr.uid, lr3.uid) - self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) - - def test_default_read_write_default_params(self): - lr = LogisticRegression() - self.assertFalse(lr.isSet(lr.getParam("threshold"))) - - lr.setMaxIter(50) - lr.setThreshold(.75) - - # `threshold` is set by user, default param `predictionCol` is not set by user. 
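# A brief illustrative sketch of the Param bookkeeping exercised below, assuming only a
# plain LogisticRegression estimator (the names here are local to the sketch, not the suite).
from pyspark.ml.classification import LogisticRegression

lr_sketch = LogisticRegression()
lr_sketch.setMaxIter(50)
assert lr_sketch.isSet(lr_sketch.maxIter)            # explicitly set by the user
assert not lr_sketch.isSet(lr_sketch.predictionCol)  # never set by the user...
assert lr_sketch.hasDefault(lr_sketch.predictionCol)  # ...but carries a default value
assert lr_sketch.getOrDefault(lr_sketch.predictionCol) == "prediction"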
- self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - writer = DefaultParamsWriter(lr) - metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) - self.assertTrue("defaultParamMap" in metadata) - - reader = DefaultParamsReadable.read() - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - # manually create metadata without `defaultParamMap` section. - del metadata['defaultParamMap'] - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): - reader.getAndSetParams(lr, loadedMetadata) - - # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. - metadata['sparkVersion'] = '2.3.0' - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - -class LDATest(SparkSessionTestCase): - - def _compare(self, m1, m2): - """ - Temp method for comparing instances. - TODO: Replace with generic implementation once SPARK-14706 is merged. - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - if m1.isDefined(p): - self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) - self.assertEqual(p.parent, m2.getParam(p.name).parent) - if isinstance(m1, LDAModel): - self.assertEqual(m1.vocabSize(), m2.vocabSize()) - self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) - - def test_persistence(self): - # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. 
- df = self.spark.createDataFrame([ - [1, Vectors.dense([0.0, 1.0])], - [2, Vectors.sparse(2, {0: 1.0})], - ], ["id", "features"]) - # Fit model - lda = LDA(k=2, seed=1, optimizer="em") - distributedModel = lda.fit(df) - self.assertTrue(distributedModel.isDistributed()) - localModel = distributedModel.toLocal() - self.assertFalse(localModel.isDistributed()) - # Define paths - path = tempfile.mkdtemp() - lda_path = path + "/lda" - dist_model_path = path + "/distLDAModel" - local_model_path = path + "/localLDAModel" - # Test LDA - lda.save(lda_path) - lda2 = LDA.load(lda_path) - self._compare(lda, lda2) - # Test DistributedLDAModel - distributedModel.save(dist_model_path) - distributedModel2 = DistributedLDAModel.load(dist_model_path) - self._compare(distributedModel, distributedModel2) - # Test LocalLDAModel - localModel.save(local_model_path) - localModel2 = LocalLDAModel.load(local_model_path) - self._compare(localModel, localModel2) - # Clean up - try: - rmtree(path) - except OSError: - pass - - -class TrainingSummaryTest(SparkSessionTestCase): - - def test_linear_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertAlmostEqual(s.explainedVariance, 0.25, 2) - self.assertAlmostEqual(s.meanAbsoluteError, 0.0) - self.assertAlmostEqual(s.meanSquaredError, 0.0) - self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) - self.assertAlmostEqual(s.r2, 1.0, 2) - self.assertAlmostEqual(s.r2adj, 1.0, 2) - self.assertTrue(isinstance(s.residuals, DataFrame)) - self.assertEqual(s.numInstances, 2) - self.assertEqual(s.degreesOfFreedom, 1) - devResiduals = s.devianceResiduals - self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class LinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) - - def test_glr_summary(self): - from pyspark.ml.linalg import Vectors - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", - fitIntercept=False) - model = glr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertEqual(s.numIterations, 1) # this should 
default to a single iteration of WLS - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.numInstances, 2) - self.assertTrue(isinstance(s.residuals(), DataFrame)) - self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - self.assertEqual(s.degreesOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedomNull, 2) - self.assertEqual(s.rank, 1) - self.assertTrue(isinstance(s.solver, basestring)) - self.assertTrue(isinstance(s.aic, float)) - self.assertTrue(isinstance(s.deviance, float)) - self.assertTrue(isinstance(s.nullDeviance, float)) - self.assertTrue(isinstance(s.dispersion, float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class GeneralizedLinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.deviance, s.deviance) - - def test_binary_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertTrue(isinstance(s.roc, DataFrame)) - self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) - self.assertTrue(isinstance(s.pr, DataFrame)) - self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) - self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) - self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) - self.assertAlmostEqual(s.accuracy, 1.0, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) - self.assertAlmostEqual(s.weightedRecall, 1.0, 2) - self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - 
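# A minimal usage sketch of the summary API these assertions cover, assuming an existing
# SparkSession and a binary-labelled DataFrame `train_df` with "label"/"features" columns.
from pyspark.ml.classification import LogisticRegression

blor_model = LogisticRegression(maxIter=5).fit(train_df)
if blor_model.hasSummary:
    summary = blor_model.summary       # training summary attached to the freshly fitted model
    print(summary.areaUnderROC)        # scalar metric for the binary case
    summary.roc.show()                 # DataFrame of (FPR, TPR) points
post_hoc = blor_model.evaluate(train_df)  # recompute a summary on any compatible DataFrame
print(post_hoc.areaUnderROC)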
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) - - def test_multiclass_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], [])), - (2.0, 2.0, Vectors.dense(2.0)), - (2.0, 2.0, Vectors.dense(1.9))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertAlmostEqual(s.accuracy, 0.75, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) - self.assertAlmostEqual(s.weightedRecall, 0.75, 2) - self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) - - def test_gaussian_mixture_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - gmm = GaussianMixture(k=2) - model = gmm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertTrue(isinstance(s.probability, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 3) - - def test_bisecting_kmeans_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - bkm = BisectingKMeans(k=2) - model = bkm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 20) - - def test_kmeans_summary(self): - data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), - (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] - df = 
self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=2, seed=1) - model = kmeans.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 1) - - -class KMeansTests(SparkSessionTestCase): - - def test_kmeans_cosine_distance(self): - data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), - (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), - (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] - df = self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") - model = kmeans.fit(df) - result = model.transform(df).collect() - self.assertTrue(result[0].prediction == result[1].prediction) - self.assertTrue(result[2].prediction == result[3].prediction) - self.assertTrue(result[4].prediction == result[5].prediction) - - -class OneVsRestTests(SparkSessionTestCase): - - def test_copy(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - ovr1 = ovr.copy({lr.maxIter: 10}) - self.assertEqual(ovr.getClassifier().getMaxIter(), 5) - self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) - model = ovr.fit(df) - model1 = model.copy({model.predictionCol: "indexed"}) - self.assertEqual(model1.getPredictionCol(), "indexed") - - def test_output_columns(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, parallelism=1) - model = ovr.fit(df) - output = model.transform(df) - self.assertEqual(output.columns, ["label", "features", "prediction"]) - - def test_parallelism_doesnt_change_output(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1) - modelPar1 = ovrPar1.fit(df) - ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2) - modelPar2 = ovrPar2.fit(df) - for i, model in enumerate(modelPar1.models): - self.assertTrue(np.allclose(model.coefficients.toArray(), - modelPar2.models[i].coefficients.toArray(), atol=1E-4)) - self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4)) - - def test_support_for_weightCol(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), - (1.0, Vectors.sparse(2, [], []), 1.0), - (2.0, Vectors.dense(0.5, 0.5), 1.0)], - ["label", "features", "weight"]) - # classifier inherits hasWeightCol - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, weightCol="weight") - self.assertIsNotNone(ovr.fit(df)) - # classifier doesn't inherit hasWeightCol - dt = DecisionTreeClassifier() - ovr2 = OneVsRest(classifier=dt, weightCol="weight") - self.assertIsNotNone(ovr2.fit(df)) - - -class HashingTFTest(SparkSessionTestCase): - - def test_apply_binary_term_freqs(self): - - df = 
self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) - n = 10 - hashingTF = HashingTF() - hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) - output = hashingTF.transform(df) - features = output.select("features").first().features.toArray() - expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray() - for i in range(0, n): - self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(features[i])) - - -class GeneralizedLinearRegressionTest(SparkSessionTestCase): - - def test_tweedie_distribution(self): - - df = self.spark.createDataFrame( - [(1.0, Vectors.dense(0.0, 0.0)), - (1.0, Vectors.dense(1.0, 2.0)), - (2.0, Vectors.dense(0.0, 0.0)), - (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) - - glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) - - model2 = glr.setLinkPower(-1.0).fit(df) - self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) - self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) - - def test_offset(self): - - df = self.spark.createDataFrame( - [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), - (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), - (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), - (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) - - glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], - atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) - - -class LinearRegressionTest(SparkSessionTestCase): - - def test_linear_regression_with_huber_loss(self): - - data_path = "data/mllib/sample_linear_regression_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lir = LinearRegression(loss="huber", epsilon=2.0) - model = lir.fit(df) - - expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, - 1.2612, -0.333, -0.5694, -0.6311, 0.6053] - expectedIntercept = 0.1607 - expectedScale = 9.758 - - self.assertTrue( - np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) - self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) - self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) - - -class LogisticRegressionTest(SparkSessionTestCase): - - def test_binomial_logistic_regression_with_bound(self): - - df = self.spark.createDataFrame( - [(1.0, 1.0, Vectors.dense(0.0, 5.0)), - (0.0, 2.0, Vectors.dense(1.0, 2.0)), - (1.0, 3.0, Vectors.dense(2.0, 1.0)), - (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"]) - - lor = LogisticRegression(regParam=0.01, weightCol="weight", - lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]), - upperBoundsOnIntercepts=Vectors.dense(0.0)) - model = lor.fit(df) - self.assertTrue( - np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4)) - - def test_multinomial_logistic_regression_with_bound(self): - - data_path = "data/mllib/sample_multiclass_classification_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lor = LogisticRegression(regParam=0.01, - 
lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)), - upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0)) - model = lor.fit(df) - expected = [[4.593, 4.5516, 9.0099, 12.2904], - [1.0, 8.1093, 7.0, 10.0], - [3.041, 5.0, 8.0, 11.0]] - for i in range(0, len(expected)): - self.assertTrue( - np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4)) - self.assertTrue( - np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4)) - - -class MultilayerPerceptronClassifierTest(SparkSessionTestCase): - - def test_raw_and_probability_prediction(self): - - data_path = "data/mllib/sample_multiclass_classification_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3], - blockSize=128, seed=123) - model = mlp.fit(df) - test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF() - result = model.transform(test).head() - expected_prediction = 2.0 - expected_probability = [0.0, 0.0, 1.0] - expected_rawPrediction = [57.3955, -124.5462, 67.9943] - self.assertTrue(result.prediction, expected_prediction) - self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4)) - self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4)) - - -class FPGrowthTests(SparkSessionTestCase): - def setUp(self): - super(FPGrowthTests, self).setUp() - self.data = self.spark.createDataFrame( - [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], - ["items"]) - - def test_association_rules(self): - fp = FPGrowth() - fpm = fp.fit(self.data) - - expected_association_rules = self.spark.createDataFrame( - [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], - ["antecedent", "consequent", "confidence", "lift"] - ) - actual_association_rules = fpm.associationRules - - self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) - self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) - - def test_freq_itemsets(self): - fp = FPGrowth() - fpm = fp.fit(self.data) - - expected_freq_itemsets = self.spark.createDataFrame( - [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], - ["items", "freq"] - ) - actual_freq_itemsets = fpm.freqItemsets - - self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) - self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) - - def tearDown(self): - del self.data - - -class ImageReaderTest(SparkSessionTestCase): - - def test_read_images(self): - data_path = 'data/mllib/images/origin/kittens' - df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - self.assertEqual(df.count(), 4) - first_row = df.take(1)[0][0] - array = ImageSchema.toNDArray(first_row) - self.assertEqual(len(array), first_row[1]) - self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) - self.assertEqual(df.schema, ImageSchema.imageSchema) - self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) - expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} - self.assertEqual(ImageSchema.ocvTypes, expected) - expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] - self.assertEqual(ImageSchema.imageFields, expected) - self.assertEqual(ImageSchema.undefinedImageType, "Undefined") - - with QuietTest(self.sc): - self.assertRaisesRegexp( - TypeError, - "image argument should be pyspark.sql.types.Row; however", - lambda: 
ImageSchema.toNDArray("a")) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - ValueError, - "image argument should have attributes specified in", - lambda: ImageSchema.toNDArray(Row(a=1))) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - TypeError, - "array argument should be numpy.ndarray; however, it got", - lambda: ImageSchema.toImage("a")) - - -class ImageReaderTest2(PySparkTestCase): - - @classmethod - def setUpClass(cls): - super(ImageReaderTest2, cls).setUpClass() - cls.hive_available = True - # Note that here we enable Hive's support. - cls.spark = None - try: - cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.tearDownClass() - cls.hive_available = False - except TypeError: - cls.tearDownClass() - cls.hive_available = False - if cls.hive_available: - cls.spark = HiveContext._createForTesting(cls.sc) - - def setUp(self): - if not self.hive_available: - self.skipTest("Hive is not available.") - - @classmethod - def tearDownClass(cls): - super(ImageReaderTest2, cls).tearDownClass() - if cls.spark is not None: - cls.spark.sparkSession.stop() - cls.spark = None - - def test_read_images_multiple_times(self): - # This test case is to check if `ImageSchema.readImages` tries to - # initiate Hive client multiple times. See SPARK-22651. - data_path = 'data/mllib/images/origin/kittens' - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - - -class ALSTest(SparkSessionTestCase): - - def test_storage_levels(self): - df = self.spark.createDataFrame( - [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], - ["user", "item", "rating"]) - als = ALS().setMaxIter(1).setRank(1) - # test default params - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK") - # test non-default params - als.setIntermediateStorageLevel("MEMORY_ONLY_2") - als.setFinalStorageLevel("DISK_ONLY") - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") - - -class DefaultValuesTests(PySparkTestCase): - """ - Test :py:class:`JavaParams` classes to see if their default Param values match - those in their Scala counterparts. 
- """ - - def test_java_params(self): - import pyspark.ml.feature - import pyspark.ml.classification - import pyspark.ml.clustering - import pyspark.ml.evaluation - import pyspark.ml.pipeline - import pyspark.ml.recommendation - import pyspark.ml.regression - - modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, - pyspark.ml.regression] - for module in modules: - for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and not name.endswith('Params')\ - and issubclass(cls, JavaParams) and not inspect.isabstract(cls): - # NOTE: disable check_params_exist until there is parity with Scala API - ParamTests.check_params(self, cls(), check_params_exist=False) - - # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel - ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), - check_params_exist=False) - ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), - check_params_exist=False) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - 
self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) 
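# A short sketch of the pyspark.ml.linalg semantics asserted around here; the vectors below
# are local examples chosen to mirror the fixtures used in these tests.
from pyspark.ml.linalg import DenseVector, SparseVector

dv_ex = DenseVector([1.0, 0.0, 3.0])
sv_ex = SparseVector(3, {0: 1.0, 2: 3.0})    # same values, sparse storage
assert dv_ex == sv_ex                        # equality compares values, not storage format
assert sv_ex.dot(dv_ex) == 10.0              # 1*1 + 0*0 + 3*3
assert sv_ex.squared_distance(dv_ex) == 0.0
assert sv_ex[1] == 0.0 and sv_ex[-1] == 3.0  # unset entries read as 0.0; negative indices count from the end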
- for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. 
- smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), - Row(label=0.0, features=self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", 
self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -class WrapperTests(MLlibTestCase): - - def test_new_java_array(self): - # test array of strings - str_list = ["a", "b", "c"] - java_class = self.sc._gateway.jvm.java.lang.String - java_array = JavaWrapper._new_java_array(str_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), str_list) - # test array of integers - int_list = [1, 2, 3] - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array(int_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), int_list) - # test array of floats - float_list = [0.1, 0.2, 0.3] - java_class = self.sc._gateway.jvm.java.lang.Double - java_array = JavaWrapper._new_java_array(float_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), float_list) - # test array of bools - bool_list = [False, True, True] - java_class = self.sc._gateway.jvm.java.lang.Boolean - java_array = JavaWrapper._new_java_array(bool_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), bool_list) - # test array of Java DenseVectors - v1 = DenseVector([0.0, 1.0]) - v2 = DenseVector([1.0, 0.0]) - vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] - java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector - java_array = JavaWrapper._new_java_array(vec_java_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) - # test empty array - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array([], java_class) - self.assertEqual(_java2py(self.sc, java_array), []) - - -class ChiSquareTestTests(SparkSessionTestCase): - - def test_chisquaretest(self): - data = [[0, Vectors.dense([0, 1, 2])], - [1, Vectors.dense([1, 1, 1])], - [2, Vectors.dense([2, 1, 0])]] - df = self.spark.createDataFrame(data, ['label', 'feat']) - res = ChiSquareTest.test(df, 'feat', 'label') - # This line is hitting the collect bug described in #17218, commented for now. 
- # pValues = res.select("degreesOfFreedom").collect()) - self.assertIsInstance(res, DataFrame) - fieldNames = set(field.name for field in res.schema.fields) - expectedFields = ["pValues", "degreesOfFreedom", "statistics"] - self.assertTrue(all(field in fieldNames for field in expectedFields)) - - -class UnaryTransformerTests(SparkSessionTestCase): - - def test_unary_transformer_validate_input_type(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - # should not raise any errors - transformer.validateInputType(DoubleType()) - - with self.assertRaises(TypeError): - # passing the wrong input type should raise an error - transformer.validateInputType(IntegerType()) - - def test_unary_transformer_transform(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - df = self.spark.range(0, 10).toDF('input') - df = df.withColumn("input", df.input.cast(dataType="double")) - - transformed_df = transformer.transform(df) - results = transformed_df.select("input", "output").collect() - - for res in results: - self.assertEqual(res.input + shiftVal, res.output) - - -class EstimatorTest(unittest.TestCase): - - def testDefaultFitMultiple(self): - N = 4 - data = MockDataset() - estimator = MockEstimator() - params = [{estimator.fake: i} for i in range(N)] - modelIter = estimator.fitMultiple(data, params) - indexList = [] - for index, model in modelIter: - self.assertEqual(model.getFake(), index) - indexList.append(index) - self.assertEqual(sorted(indexList), list(range(N))) - - -if __name__ == "__main__": - from pyspark.ml.tests import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) diff --git a/python/pyspark/ml/tests/__init__.py b/python/pyspark/ml/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/ml/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py new file mode 100644 index 0000000000000..1a72e124962c8 --- /dev/null +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -0,0 +1,349 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from shutil import rmtree +import sys +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import numpy as np + +from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, \ + MultilayerPerceptronClassifier, OneVsRest +from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel +from pyspark.ml.fpm import FPGrowth +from pyspark.ml.linalg import Matrices, Vectors +from pyspark.ml.recommendation import ALS +from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression +from pyspark.sql import Row +from pyspark.testing.mlutils import SparkSessionTestCase + + +class LogisticRegressionTest(SparkSessionTestCase): + + def test_binomial_logistic_regression_with_bound(self): + + df = self.spark.createDataFrame( + [(1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.0, 3.0, Vectors.dense(2.0, 1.0)), + (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"]) + + lor = LogisticRegression(regParam=0.01, weightCol="weight", + lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]), + upperBoundsOnIntercepts=Vectors.dense(0.0)) + model = lor.fit(df) + self.assertTrue( + np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4)) + + def test_multinomial_logistic_regression_with_bound(self): + + data_path = "data/mllib/sample_multiclass_classification_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + lor = LogisticRegression(regParam=0.01, + lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)), + upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0)) + model = lor.fit(df) + expected = [[4.593, 4.5516, 9.0099, 12.2904], + [1.0, 8.1093, 7.0, 10.0], + [3.041, 5.0, 8.0, 11.0]] + for i in range(0, len(expected)): + self.assertTrue( + np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4)) + self.assertTrue( + np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4)) + + +class MultilayerPerceptronClassifierTest(SparkSessionTestCase): + + def test_raw_and_probability_prediction(self): + + data_path = "data/mllib/sample_multiclass_classification_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3], + blockSize=128, seed=123) + model = mlp.fit(df) + test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF() + result = model.transform(test).head() + expected_prediction = 2.0 + expected_probability = [0.0, 0.0, 1.0] + expected_rawPrediction = [57.3955, -124.5462, 67.9943] + self.assertTrue(result.prediction, expected_prediction) + self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4)) + self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4)) + + +class 
OneVsRestTests(SparkSessionTestCase): + + def test_copy(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + ovr1 = ovr.copy({lr.maxIter: 10}) + self.assertEqual(ovr.getClassifier().getMaxIter(), 5) + self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) + model = ovr.fit(df) + model1 = model.copy({model.predictionCol: "indexed"}) + self.assertEqual(model1.getPredictionCol(), "indexed") + + def test_output_columns(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, parallelism=1) + model = ovr.fit(df) + output = model.transform(df) + self.assertEqual(output.columns, ["label", "features", "prediction"]) + + def test_parallelism_doesnt_change_output(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1) + modelPar1 = ovrPar1.fit(df) + ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2) + modelPar2 = ovrPar2.fit(df) + for i, model in enumerate(modelPar1.models): + self.assertTrue(np.allclose(model.coefficients.toArray(), + modelPar2.models[i].coefficients.toArray(), atol=1E-4)) + self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4)) + + def test_support_for_weightCol(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), + (1.0, Vectors.sparse(2, [], []), 1.0), + (2.0, Vectors.dense(0.5, 0.5), 1.0)], + ["label", "features", "weight"]) + # classifier inherits hasWeightCol + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, weightCol="weight") + self.assertIsNotNone(ovr.fit(df)) + # classifier doesn't inherit hasWeightCol + dt = DecisionTreeClassifier() + ovr2 = OneVsRest(classifier=dt, weightCol="weight") + self.assertIsNotNone(ovr2.fit(df)) + + +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + +class LDATest(SparkSessionTestCase): + + def _compare(self, m1, m2): + """ + Temp method for comparing instances. + TODO: Replace with generic implementation once SPARK-14706 is merged. 
+ """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + if m1.isDefined(p): + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + if isinstance(m1, LDAModel): + self.assertEqual(m1.vocabSize(), m2.vocabSize()) + self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) + + def test_persistence(self): + # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. + df = self.spark.createDataFrame([ + [1, Vectors.dense([0.0, 1.0])], + [2, Vectors.sparse(2, {0: 1.0})], + ], ["id", "features"]) + # Fit model + lda = LDA(k=2, seed=1, optimizer="em") + distributedModel = lda.fit(df) + self.assertTrue(distributedModel.isDistributed()) + localModel = distributedModel.toLocal() + self.assertFalse(localModel.isDistributed()) + # Define paths + path = tempfile.mkdtemp() + lda_path = path + "/lda" + dist_model_path = path + "/distLDAModel" + local_model_path = path + "/localLDAModel" + # Test LDA + lda.save(lda_path) + lda2 = LDA.load(lda_path) + self._compare(lda, lda2) + # Test DistributedLDAModel + distributedModel.save(dist_model_path) + distributedModel2 = DistributedLDAModel.load(dist_model_path) + self._compare(distributedModel, distributedModel2) + # Test LocalLDAModel + localModel.save(local_model_path) + localModel2 = LocalLDAModel.load(local_model_path) + self._compare(localModel, localModel2) + # Clean up + try: + rmtree(path) + except OSError: + pass + + +class FPGrowthTests(SparkSessionTestCase): + def setUp(self): + super(FPGrowthTests, self).setUp() + self.data = self.spark.createDataFrame( + [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], + ["items"]) + + def test_association_rules(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_association_rules = self.spark.createDataFrame( + [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], + ["antecedent", "consequent", "confidence", "lift"] + ) + actual_association_rules = fpm.associationRules + + self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) + self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) + + def test_freq_itemsets(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_freq_itemsets = self.spark.createDataFrame( + [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], + ["items", "freq"] + ) + actual_freq_itemsets = fpm.freqItemsets + + self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) + self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) + + def tearDown(self): + del self.data + + +class ALSTest(SparkSessionTestCase): + + def test_storage_levels(self): + df = self.spark.createDataFrame( + [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ["user", "item", "rating"]) + als = ALS().setMaxIter(1).setRank(1) + # test default params + als.fit(df) + self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK") + # test non-default params + als.setIntermediateStorageLevel("MEMORY_ONLY_2") + als.setFinalStorageLevel("DISK_ONLY") + als.fit(df) + self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2") + 
self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2") + self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY") + self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") + + +class GeneralizedLinearRegressionTest(SparkSessionTestCase): + + def test_tweedie_distribution(self): + + df = self.spark.createDataFrame( + [(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) + + glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) + + model2 = glr.setLinkPower(-1.0).fit(df) + self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) + self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + + def test_offset(self): + + df = self.spark.createDataFrame( + [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) + + glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], + atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) + + +class LinearRegressionTest(SparkSessionTestCase): + + def test_linear_regression_with_huber_loss(self): + + data_path = "data/mllib/sample_linear_regression_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + lir = LinearRegression(loss="huber", epsilon=2.0) + model = lir.fit(df) + + expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, + 1.2612, -0.333, -0.5694, -0.6311, 0.6053] + expectedIntercept = 0.1607 + expectedScale = 9.758 + + self.assertTrue( + np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) + self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) + self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py new file mode 100644 index 0000000000000..59c45f638dd45 --- /dev/null +++ b/python/pyspark/ml/tests/test_base.py @@ -0,0 +1,85 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql.types import DoubleType, IntegerType +from pyspark.testing.mlutils import MockDataset, MockEstimator, MockUnaryTransformer, \ + SparkSessionTestCase + + +class UnaryTransformerTests(SparkSessionTestCase): + + def test_unary_transformer_validate_input_type(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal) \ + .setInputCol("input").setOutputCol("output") + + # should not raise any errors + transformer.validateInputType(DoubleType()) + + with self.assertRaises(TypeError): + # passing the wrong input type should raise an error + transformer.validateInputType(IntegerType()) + + def test_unary_transformer_transform(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal) \ + .setInputCol("input").setOutputCol("output") + + df = self.spark.range(0, 10).toDF('input') + df = df.withColumn("input", df.input.cast(dataType="double")) + + transformed_df = transformer.transform(df) + results = transformed_df.select("input", "output").collect() + + for res in results: + self.assertEqual(res.input + shiftVal, res.output) + + +class EstimatorTest(unittest.TestCase): + + def testDefaultFitMultiple(self): + N = 4 + data = MockDataset() + estimator = MockEstimator() + params = [{estimator.fake: i} for i in range(N)] + modelIter = estimator.fitMultiple(data, params) + indexList = [] + for index, model in modelIter: + self.assertEqual(model.getFake(), index) + indexList.append(index) + self.assertEqual(sorted(indexList), list(range(N))) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_base import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py new file mode 100644 index 0000000000000..6c3e5c6734509 --- /dev/null +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import numpy as np + +from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator +from pyspark.ml.linalg import Vectors +from pyspark.sql import Row +from pyspark.testing.mlutils import SparkSessionTestCase + + +class EvaluatorTests(SparkSessionTestCase): + + def test_java_params(self): + """ + This tests a bug fixed by SPARK-18274 which causes multiple copies + of a Params instance in Python to be linked to the same Java instance. + """ + evaluator = RegressionEvaluator(metricName="r2") + df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) + evaluator.evaluate(df) + self.assertEqual(evaluator._java_obj.getMetricName(), "r2") + evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"}) + evaluator.evaluate(df) + evaluatorCopy.evaluate(df) + self.assertEqual(evaluator._java_obj.getMetricName(), "r2") + self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_evaluation import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py new file mode 100644 index 0000000000000..23f66e73b4820 --- /dev/null +++ b/python/pyspark/ml/tests/test_feature.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +if sys.version > '3': + basestring = str + +from pyspark.ml.feature import Binarizer, CountVectorizer, CountVectorizerModel, HashingTF, IDF, \ + NGram, RFormula, StopWordsRemover, StringIndexer, StringIndexerModel, VectorSizeHint +from pyspark.ml.linalg import DenseVector, SparseVector, Vectors +from pyspark.sql import Row +from pyspark.testing.utils import QuietTest +from pyspark.testing.mlutils import check_params, SparkSessionTestCase + + +class FeatureTests(SparkSessionTestCase): + + def test_binarizer(self): + b0 = Binarizer() + self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) + self.assertTrue(all([~b0.isSet(p) for p in b0.params])) + self.assertTrue(b0.hasDefault(b0.threshold)) + self.assertEqual(b0.getThreshold(), 0.0) + b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) + self.assertTrue(all([b0.isSet(p) for p in b0.params])) + self.assertEqual(b0.getThreshold(), 1.0) + self.assertEqual(b0.getInputCol(), "input") + self.assertEqual(b0.getOutputCol(), "output") + + b0c = b0.copy({b0.threshold: 2.0}) + self.assertEqual(b0c.uid, b0.uid) + self.assertListEqual(b0c.params, b0.params) + self.assertEqual(b0c.getThreshold(), 2.0) + + b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") + self.assertNotEqual(b1.uid, b0.uid) + self.assertEqual(b1.getThreshold(), 2.0) + self.assertEqual(b1.getInputCol(), "input") + self.assertEqual(b1.getOutputCol(), "output") + + def test_idf(self): + dataset = self.spark.createDataFrame([ + (DenseVector([1.0, 2.0]),), + (DenseVector([0.0, 1.0]),), + (DenseVector([3.0, 0.2]),)], ["tf"]) + idf0 = IDF(inputCol="tf") + self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) + idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) + self.assertEqual(idf0m.uid, idf0.uid, + "Model should inherit the UID from its parent estimator.") + output = idf0m.transform(dataset) + self.assertIsNotNone(output.head().idf) + # Test that parameters transferred to Python Model + check_params(self, idf0m) + + def test_ngram(self): + dataset = self.spark.createDataFrame([ + Row(input=["a", "b", "c", "d", "e"])]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) + + def test_stopwordsremover(self): + dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEqual(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["panda"]) + self.assertEqual(type(stopWordRemover.getStopWords()), list) + self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring)) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a"]) + # with language selection + stopwords = 
StopWordsRemover.loadDefaultStopWords("turkish") + dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + + def test_count_vectorizer_with_binary(self): + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) + cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") + model = cv.fit(dataset) + + transformedList = model.transform(dataset).select("features", "expected").collect() + + for r in transformedList: + feature, expected = r + self.assertEqual(feature, expected) + + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset) \ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + + def test_rformula_force_index_label(self): + df = self.spark.createDataFrame([ + 
(1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + # Does not index label by default since it's numeric type. + rf = RFormula(formula="y ~ x + s") + model = rf.fit(df) + transformedDF = model.transform(df) + self.assertEqual(transformedDF.head().label, 1.0) + # Force to index label. + rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) + model2 = rf2.fit(df) + transformedDF2 = model2.transform(df) + self.assertEqual(transformedDF2.head().label, 0.0) + + def test_rformula_string_indexer_order_type(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") + self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') + transformedDF = rf.fit(df).transform(df) + observed = transformedDF.select("features").collect() + expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] + for i in range(0, len(expected)): + self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2) \ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = 
sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + + +class HashingTFTest(SparkSessionTestCase): + + def test_apply_binary_term_freqs(self): + + df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + n = 10 + hashingTF = HashingTF() + hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) + output = hashingTF.transform(df) + features = output.select("features").first().features.toArray() + expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray() + for i in range(0, n): + self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(features[i])) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_feature import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py new file mode 100644 index 0000000000000..dcc7a32c9fd70 --- /dev/null +++ b/python/pyspark/ml/tests/test_image.py @@ -0,0 +1,118 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import py4j + +from pyspark.ml.image import ImageSchema +from pyspark.testing.mlutils import PySparkTestCase, SparkSessionTestCase +from pyspark.sql import HiveContext, Row +from pyspark.testing.utils import QuietTest + + +class ImageReaderTest(SparkSessionTestCase): + + def test_read_images(self): + data_path = 'data/mllib/images/origin/kittens' + df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + self.assertEqual(df.count(), 4) + first_row = df.take(1)[0][0] + array = ImageSchema.toNDArray(first_row) + self.assertEqual(len(array), first_row[1]) + self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) + self.assertEqual(df.schema, ImageSchema.imageSchema) + self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) + expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} + self.assertEqual(ImageSchema.ocvTypes, expected) + expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] + self.assertEqual(ImageSchema.imageFields, expected) + self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "image argument should be pyspark.sql.types.Row; however", + lambda: ImageSchema.toNDArray("a")) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + ValueError, + "image argument should have attributes specified in", + lambda: ImageSchema.toNDArray(Row(a=1))) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "array argument should be numpy.ndarray; however, it got", + lambda: ImageSchema.toImage("a")) + + +class ImageReaderTest2(PySparkTestCase): + + @classmethod + def setUpClass(cls): + super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True + # Note that here we enable Hive's support. + cls.spark = None + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.tearDownClass() + cls.hive_available = False + except TypeError: + cls.tearDownClass() + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") + + @classmethod + def tearDownClass(cls): + super(ImageReaderTest2, cls).tearDownClass() + if cls.spark is not None: + cls.spark.sparkSession.stop() + cls.spark = None + + def test_read_images_multiple_times(self): + # This test case is to check if `ImageSchema.readImages` tries to + # initiate Hive client multiple times. See SPARK-22651. 
+ data_path = 'data/mllib/images/origin/kittens' + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_image import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py new file mode 100644 index 0000000000000..76e5386e86125 --- /dev/null +++ b/python/pyspark/ml/tests/test_linalg.py @@ -0,0 +1,392 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import array as pyarray +from numpy import arange, array, array_equal, inf, ones, tile, zeros + +from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \ + Vector, VectorUDT, Vectors +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase +from pyspark.sql import Row + + +ser = make_serializer() + + +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + + +class VectorTests(MLlibTestCase): + + def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) + jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) + + def test_serialize(self): + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) + + def test_dot(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) + mat = array([[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) + self.assertEqual(10.0, sv.dot(dv)) + 
self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEqual(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEqual(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEqual(7.0, sv.dot(arr)) + + def test_squared_distance(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) + + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + + def test_sparse_vector_indexing(self): + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) 
+ for ind in [5, -6]: + self.assertRaises(IndexError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(IndexError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEqual(mat[i, j], expected[i][j]) + + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + + # Test conversion to dense and sparse. 
+ smnew = sm1.toDense().toSparse() + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEqual(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEqual(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + +class VectorUDTTests(MLlibTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), + Row(label=0.0, features=self.sv1)]) + df = rdd.toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = df.rdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([("dense", 
self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.rdd.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_linalg import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py new file mode 100644 index 0000000000000..1f36d4544ab92 --- /dev/null +++ b/python/pyspark/ml/tests/test_param.py @@ -0,0 +1,372 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +import sys +import array as pyarray +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +if sys.version > '3': + xrange = range + +import numpy as np + +from pyspark import keyword_only +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import KMeans +from pyspark.ml.feature import Binarizer, Bucketizer, ElementwiseProduct, IndexToString, \ + VectorSlicer, Word2Vec +from pyspark.ml.linalg import DenseVector, SparseVector +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed +from pyspark.ml.wrapper import JavaParams +from pyspark.testing.mlutils import check_params, PySparkTestCase, SparkSessionTestCase + + +class ParamTypeConversionTests(PySparkTestCase): + """ + Test that param type conversion happens. 
+ """ + + def test_int(self): + lr = LogisticRegression(maxIter=5.0) + self.assertEqual(lr.getMaxIter(), 5) + self.assertTrue(type(lr.getMaxIter()) == int) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) + + def test_float(self): + lr = LogisticRegression(tol=1) + self.assertEqual(lr.getTol(), 1.0) + self.assertTrue(type(lr.getTol()) == float) + self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) + + def test_vector(self): + ewp = ElementwiseProduct(scalingVec=[1, 3]) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) + ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) + self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) + + def test_list(self): + l = [0, 1] + for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l), + pyarray.array('l', l), xrange(2), tuple(l)]: + converted = TypeConverters.toList(lst_like) + self.assertEqual(type(converted), list) + self.assertListEqual(converted, l) + + def test_list_int(self): + for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), + SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), + pyarray.array('d', [1.0, 2.0])]: + vs = VectorSlicer(indices=indices) + self.assertListEqual(vs.getIndices(), [1, 2]) + self.assertTrue(all([type(v) == int for v in vs.getIndices()])) + self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) + + def test_list_float(self): + b = Bucketizer(splits=[1, 4]) + self.assertEqual(b.getSplits(), [1.0, 4.0]) + self.assertTrue(all([type(v) == float for v in b.getSplits()])) + self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) + + def test_list_string(self): + for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: + idx_to_string = IndexToString(labels=labels) + self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) + self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) + + def test_string(self): + lr = LogisticRegression() + for col in ['features', u'features', np.str_('features')]: + lr.setFeaturesCol(col) + self.assertEqual(lr.getFeaturesCol(), 'features') + self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) + + def test_bool(self): + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) + + +class TestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + +class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(OtherTestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. 
+ """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + +class HasThrowableProperty(Params): + + def __init__(self): + super(HasThrowableProperty, self).__init__() + self.p = Param(self, "none", "empty param") + + @property + def test_property(self): + raise RuntimeError("Test property to raise error when invoked") + + +class ParamTests(SparkSessionTestCase): + + def test_copy_new_parent(self): + testParams = TestParams() + # Copying an instantiated param should fail + with self.assertRaises(ValueError): + testParams.maxIter._copy_new_parent(testParams) + # Copying a dummy param should succeed + TestParams.maxIter._copy_new_parent(testParams) + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + self.assertTrue(testParams.hasParam(u"maxIter")) + + def test_resolveparam(self): + testParams = TestParams() + self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) + self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) + + self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) + if sys.version_info[0] >= 3: + # In Python 3, it is allowed to get/set attributes with non-ascii characters. + e_cls = AttributeError + else: + e_cls = UnicodeEncodeError + self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + seed = testParams.seed + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter, seed]) + + self.assertTrue(testParams.hasParam(maxIter.name)) + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEqual(testParams.getMaxIter(), 100) + + self.assertTrue(testParams.hasParam(inputCol.name)) + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " + + "set raises an error for a non-member parameter.", + typeConverter=TypeConverters.toString) + with self.assertRaises(ValueError): + testParams.set(otherParam, "value") + + # Since the default is normally random, set it to a known number for debug str + testParams._setDefault(seed=41) + testParams.setSeed(43) + + self.assertEqual( + testParams.explainParams(), + "\n".join(["inputCol: input column name. (undefined)", + "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", + "seed: random seed. 
(default: 41, current: 43)"])) + + def test_kmeans_param(self): + algo = KMeans() + self.assertEqual(algo.getInitMode(), "k-means||") + algo.setK(10) + self.assertEqual(algo.getK(), 10) + algo.setInitSteps(10) + self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") + + def test_hasseed(self): + noSeedSpecd = TestParams() + withSeedSpecd = TestParams(seed=42) + other = OtherTestParams() + # Check that we no longer use 42 as the magic number + self.assertNotEqual(noSeedSpecd.getSeed(), 42) + origSeed = noSeedSpecd.getSeed() + # Check that we only compute the seed once + self.assertEqual(noSeedSpecd.getSeed(), origSeed) + # Check that a specified seed is honored + self.assertEqual(withSeedSpecd.getSeed(), 42) + # Check that a different class has a different seed + self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + + def test_param_property_error(self): + param_store = HasThrowableProperty() + self.assertRaises(RuntimeError, lambda: param_store.test_property) + params = param_store.params # should not invoke the property 'test_property' + self.assertEqual(len(params), 1) + + def test_word2vec_param(self): + model = Word2Vec().setWindowSize(6) + # Check windowSize is set properly + self.assertEqual(model.getWindowSize(), 6) + + def test_copy_param_extras(self): + tp = TestParams(seed=42) + extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} + tp_copy = tp.copy(extra=extra) + self.assertEqual(tp.uid, tp_copy.uid) + self.assertEqual(tp.params, tp_copy.params) + for k, v in extra.items(): + self.assertTrue(tp_copy.isDefined(k)) + self.assertEqual(tp_copy.getOrDefault(k), v) + copied_no_extra = {} + for k, v in tp_copy._paramMap.items(): + if k not in extra: + copied_no_extra[k] = v + self.assertEqual(tp._paramMap, copied_no_extra) + self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + + def test_logistic_regression_check_thresholds(self): + self.assertIsInstance( + LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), + LogisticRegression + ) + + self.assertRaisesRegexp( + ValueError, + "Logistic Regression getThreshold found inconsistent.*$", + LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] + ) + + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + + +class DefaultValuesTests(PySparkTestCase): + """ + Test :py:class:`JavaParams` classes to see if their default Param values match + those in their Scala counterparts. 
+ """ + + def test_java_params(self): + import pyspark.ml.feature + import pyspark.ml.classification + import pyspark.ml.clustering + import pyspark.ml.evaluation + import pyspark.ml.pipeline + import pyspark.ml.recommendation + import pyspark.ml.regression + + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] + for module in modules: + for name, cls in inspect.getmembers(module, inspect.isclass): + if not name.endswith('Model') and not name.endswith('Params') \ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): + # NOTE: disable check_params_exist until there is parity with Scala API + check_params(self, cls(), check_params_exist=False) + + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel + check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_param import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py new file mode 100644 index 0000000000000..b5a2e16df5532 --- /dev/null +++ b/python/pyspark/ml/tests/test_persistence.py @@ -0,0 +1,369 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +from shutil import rmtree +import sys +import tempfile +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml import Transformer +from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \ + OneVsRestModel +from pyspark.ml.feature import Binarizer, HashingTF, PCA +from pyspark.ml.linalg import Vectors +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Pipeline, PipelineModel +from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter +from pyspark.ml.wrapper import JavaParams +from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase + + +class PersistenceTest(SparkSessionTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(type(lr.uid), type(lr2.uid)) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LinearRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + + def test_logistic_regression(self): + lr = LogisticRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/logreg" + lr.save(lr_path) + lr2 = LogisticRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LogisticRegression instance uid (%s) " + "did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LogisticRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def _compare_params(self, m1, m2, param): + """ + Compare 2 ML Params instances for the given param, and assert both have the same param value + and parent. The param must be a parameter of m1. + """ + # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. 
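+        # A param is only compared when it is defined on m1; values that are themselves
+        # Params (e.g. nested estimators) are compared recursively via _compare_pipelines.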
+ if m1.isDefined(param): + paramValue1 = m1.getOrDefault(param) + paramValue2 = m2.getOrDefault(m2.getParam(param.name)) + if isinstance(paramValue1, Params): + self._compare_pipelines(paramValue1, paramValue2) + else: + self.assertEqual(paramValue1, paramValue2) # for general types param + # Assert parents are equal + self.assertEqual(param.parent, m2.getParam(param.name).parent) + else: + # If m1 is not defined param, then m2 should not, too. See SPARK-14931. + self.assertFalse(m2.isDefined(m2.getParam(param.name))) + + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + - OneVsRest, OneVsRestModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaParams) or isinstance(m1, Transformer): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self._compare_params(m1, m2, p) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): + for p in m1.params: + self._compare_params(m1, m2, p) + if isinstance(m1, OneVsRestModel): + self.assertEqual(len(m1.models), len(m2.models)) + for x, y in zip(m1.models, m2.models): + self._compare_pipelines(x, y) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + pl = Pipeline(stages=[tf, pca]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_python_transformer_pipeline_persistence(self): + """ + Pipeline[MockUnaryTransformer, Binarizer] + """ + temp_path = tempfile.mkdtemp() 
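+        # The pipeline below mixes a Python-only stage (MockUnaryTransformer, persisted via
+        # DefaultParamsWritable) with a Java-backed stage (Binarizer), so save/load covers both paths.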
+ + try: + df = self.spark.range(0, 10).toDF('input') + tf = MockUnaryTransformer(shiftVal=2)\ + .setInputCol("input").setOutputCol("shiftedInput") + tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") + pl = Pipeline(stages=[tf, tf2]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_onevsrest(self): + temp_path = tempfile.mkdtemp() + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))] * 10, + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + ovrPath = temp_path + "/ovr" + ovr.save(ovrPath) + loadedOvr = OneVsRest.load(ovrPath) + self._compare_pipelines(ovr, loadedOvr) + modelPath = temp_path + "/ovrModel" + model.save(modelPath) + loadedModel = OneVsRestModel.load(modelPath) + self._compare_pipelines(model, loadedModel) + + def test_decisiontree_classifier(self): + dt = DecisionTreeClassifier(maxDepth=1) + path = tempfile.mkdtemp() + dtc_path = path + "/dtc" + dt.save(dtc_path) + dt2 = DecisionTreeClassifier.load(dtc_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeClassifier instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeClassifier instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_decisiontree_regressor(self): + dt = DecisionTreeRegressor(maxDepth=1) + path = tempfile.mkdtemp() + dtr_path = path + "/dtr" + dt.save(dtr_path) + dt2 = DecisionTreeClassifier.load(dtr_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeRegressor instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeRegressor instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_default_read_write(self): + temp_path = tempfile.mkdtemp() + + lr = LogisticRegression() + lr.setMaxIter(50) + lr.setThreshold(.75) + writer = DefaultParamsWriter(lr) + + savePath = temp_path + "/lr" + writer.save(savePath) + + reader = DefaultParamsReadable.read() + lr2 = reader.load(savePath) + + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(lr.extractParamMap(), lr2.extractParamMap()) + + # test overwrite + lr.setThreshold(.8) + writer.overwrite().save(savePath) + + reader = DefaultParamsReadable.read() + lr3 = reader.load(savePath) + + self.assertEqual(lr.uid, lr3.uid) + self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. 
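+        # The assertions and metadata round-trip below check that user-set params and default
+        # params are persisted separately by DefaultParamsWriter and the corresponding reader.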
+ self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_persistence import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py new file mode 100644 index 0000000000000..31ef02c2e601f --- /dev/null +++ b/python/pyspark/ml/tests/test_pipeline.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml.pipeline import Pipeline +from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase + + +class PipelineTests(PySparkTestCase): + + def test_pipeline(self): + dataset = MockDataset() + estimator0 = MockEstimator() + transformer1 = MockTransformer() + estimator2 = MockEstimator() + transformer3 = MockTransformer() + pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) + pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) + model0, transformer1, model2, transformer3 = pipeline_model.stages + self.assertEqual(0, model0.dataset_index) + self.assertEqual(0, model0.getFake()) + self.assertEqual(1, transformer1.dataset_index) + self.assertEqual(1, transformer1.getFake()) + self.assertEqual(2, dataset.index) + self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") + self.assertIsNone(transformer3.dataset_index, + "The last transformer shouldn't be called in fit.") + dataset = pipeline_model.transform(dataset) + self.assertEqual(2, model0.dataset_index) + self.assertEqual(3, transformer1.dataset_index) + self.assertEqual(4, model2.dataset_index) + self.assertEqual(5, transformer3.dataset_index) + self.assertEqual(6, dataset.index) + + def test_identity_pipeline(self): + dataset = MockDataset() + + def doTransform(pipeline): + pipeline_model = pipeline.fit(dataset) + return pipeline_model.transform(dataset) + # check that empty pipeline did not perform any transformation + self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) + # check that failure to set stages param will raise KeyError for missing param + self.assertRaises(KeyError, lambda: doTransform(Pipeline())) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_pipeline import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py new file mode 100644 index 0000000000000..bdc4853bc05c2 --- /dev/null +++ b/python/pyspark/ml/tests/test_stat.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml.linalg import Vectors +from pyspark.ml.stat import ChiSquareTest +from pyspark.sql import DataFrame +from pyspark.testing.mlutils import SparkSessionTestCase + + +class ChiSquareTestTests(SparkSessionTestCase): + + def test_chisquaretest(self): + data = [[0, Vectors.dense([0, 1, 2])], + [1, Vectors.dense([1, 1, 1])], + [2, Vectors.dense([2, 1, 0])]] + df = self.spark.createDataFrame(data, ['label', 'feat']) + res = ChiSquareTest.test(df, 'feat', 'label') + # This line is hitting the collect bug described in #17218, commented for now. + # pValues = res.select("degreesOfFreedom").collect()) + self.assertIsInstance(res, DataFrame) + fieldNames = set(field.name for field in res.schema.fields) + expectedFields = ["pValues", "degreesOfFreedom", "statistics"] + self.assertTrue(all(field in fieldNames for field in expectedFields)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_stat import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py new file mode 100644 index 0000000000000..d5464f7be6372 --- /dev/null +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +if sys.version > '3': + basestring = str + +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans +from pyspark.ml.linalg import Vectors +from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression +from pyspark.sql import DataFrame +from pyspark.testing.mlutils import SparkSessionTestCase + + +class TrainingSummaryTest(SparkSessionTestCase): + + def test_linear_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertAlmostEqual(s.explainedVariance, 0.25, 2) + self.assertAlmostEqual(s.meanAbsoluteError, 0.0) + self.assertAlmostEqual(s.meanSquaredError, 0.0) + self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) + self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) + self.assertTrue(isinstance(s.residuals, DataFrame)) + self.assertEqual(s.numInstances, 2) + self.assertEqual(s.degreesOfFreedom, 1) + devResiduals = s.devianceResiduals + self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned + # The child class LinearRegressionTrainingSummary runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) + + def test_glr_summary(self): + from pyspark.ml.linalg import Vectors + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", + fitIntercept=False) + model = glr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.numInstances, 2) + self.assertTrue(isinstance(s.residuals(), DataFrame)) + self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) + coefStdErr = s.coefficientStandardErrors + 
self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + self.assertEqual(s.degreesOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedomNull, 2) + self.assertEqual(s.rank, 1) + self.assertTrue(isinstance(s.solver, basestring)) + self.assertTrue(isinstance(s.aic, float)) + self.assertTrue(isinstance(s.deviance, float)) + self.assertTrue(isinstance(s.nullDeviance, float)) + self.assertTrue(isinstance(s.dispersion, float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned + # The child class GeneralizedLinearRegressionTrainingSummary runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.deviance, s.deviance) + + def test_binary_logistic_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + + def test_multiclass_logistic_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], [])), + (2.0, 2.0, Vectors.dense(2.0)), + (2.0, 2.0, Vectors.dense(1.9))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, 
weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertAlmostEqual(s.accuracy, 0.75, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) + self.assertAlmostEqual(s.weightedRecall, 0.75, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) + + def test_gaussian_mixture_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + gmm = GaussianMixture(k=2) + model = gmm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertTrue(isinstance(s.probability, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 3) + + def test_bisecting_kmeans_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + bkm = BisectingKMeans(k=2) + model = bkm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 20) + + def test_kmeans_summary(self): + data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=2, seed=1) + model = kmeans.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + 
self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 1) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_training_summary import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py new file mode 100644 index 0000000000000..af00d1de7ab6a --- /dev/null +++ b/python/pyspark/ml/tests/test_tuning.py @@ -0,0 +1,552 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import tempfile +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml import Estimator, Model +from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest +from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ + MulticlassClassificationEvaluator, RegressionEvaluator +from pyspark.ml.linalg import Vectors +from pyspark.ml.param import Param, Params +from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \ + TrainValidationSplit, TrainValidationSplitModel +from pyspark.sql.functions import rand +from pyspark.testing.mlutils import SparkSessionTestCase + + +class HasInducedError(Params): + + def __init__(self): + super(HasInducedError, self).__init__() + self.inducedError = Param(self, "inducedError", + "Uniformly-distributed error added to feature") + + def getInducedError(self): + return self.getOrDefault(self.inducedError) + + +class InducedErrorModel(Model, HasInducedError): + + def __init__(self): + super(InducedErrorModel, self).__init__() + + def _transform(self, dataset): + return dataset.withColumn("prediction", + dataset.feature + (rand(0) * self.getInducedError())) + + +class InducedErrorEstimator(Estimator, HasInducedError): + + def __init__(self, inducedError=1.0): + super(InducedErrorEstimator, self).__init__() + self._set(inducedError=inducedError) + + def _fit(self, dataset): + model = InducedErrorModel() + self._copyValues(model) + return model + + +class CrossValidatorTests(SparkSessionTestCase): + + def test_copy(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvCopied = cv.copy() + 
self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) + + cvModel = cv.fit(dataset) + cvModelCopied = cvModel.copy() + for index in range(len(cvModel.avgMetrics)): + self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) + < 0.0001) + + def test_fit_minimize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + + def test_fit_maximize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + + def test_param_grid_type_coercion(self): + lr = LogisticRegression(maxIter=10) + paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() + for param in paramGrid: + for v in param.values(): + assert(type(v) == float) + + def test_save_load_trained_model(self): + # This tests saving and loading the trained model only. 
+ # Save/load for CrossValidator will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + lrModel = cvModel.bestModel + + cvModelPath = temp_path + "/cvModel" + lrModel.save(cvModelPath) + loadedLrModel = LogisticRegressionModel.load(cvModelPath) + self.assertEqual(loadedLrModel.uid, lrModel.uid) + self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + + def test_save_load_simple_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cv.setParallelism(1) + cvSerialModel = cv.fit(dataset) + cv.setParallelism(2) + cvParallelModel = cv.fit(dataset) + self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option 
"persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + + def test_save_load_nested_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + + originalParamMap = cv.getEstimatorParamMaps() + loadedParamMap = loadedCV.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + +class TrainValidationSplitTests(SparkSessionTestCase): + + def test_fit_minimize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(0.0, min(validationMetrics)) + + def test_fit_maximize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = 
ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(1.0, max(validationMetrics)) + + def test_save_load_trained_model(self): + # This tests saving and loading the trained model only. + # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + lrModel = tvsModel.bestModel + + tvsModelPath = temp_path + "/tvsModel" + lrModel.save(tvsModelPath) + loadedLrModel = LogisticRegressionModel.load(tvsModelPath) + self.assertEqual(loadedLrModel.uid, lrModel.uid) + self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + + def test_save_load_simple_estimator(self): + # This tests saving and loading the trained model only. 
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvs.setParallelism(1) + tvsSerialModel = tvs.fit(dataset) + tvs.setParallelism(2) + tvsParallelModel = tvs.fit(dataset) + self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + + def test_save_load_nested_estimator(self): + # This tests saving and loading the trained model only. 
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + + originalParamMap = tvs.getEstimatorParamMaps() + loadedParamMap = loadedTvs.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + def test_copy(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsCopied = tvs.copy() + tvsModelCopied = tvsModel.copy() + + self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, + "Copied TrainValidationSplit has the same uid of Estimator") + + self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) + self.assertEqual(len(tvsModel.validationMetrics), + len(tvsModelCopied.validationMetrics), + "Copied validationMetrics has the same size of the original") + for index in range(len(tvsModel.validationMetrics)): + self.assertEqual(tvsModel.validationMetrics[index], + tvsModelCopied.validationMetrics[index]) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_tuning import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py new file mode 100644 index 0000000000000..4326d8e060dd7 --- /dev/null +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import py4j + +from pyspark.ml.linalg import DenseVector, Vectors +from pyspark.ml.regression import LinearRegression +from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper +from pyspark.testing.mllibutils import MLlibTestCase +from pyspark.testing.mlutils import SparkSessionTestCase + + +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + +class WrapperTests(MLlibTestCase): + + def test_new_java_array(self): + # test array of strings + str_list = ["a", "b", "c"] + java_class = self.sc._gateway.jvm.java.lang.String + java_array = JavaWrapper._new_java_array(str_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), str_list) + # test array of integers + int_list = [1, 2, 3] + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array(int_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), int_list) + # test array of floats + float_list = [0.1, 0.2, 0.3] + java_class = self.sc._gateway.jvm.java.lang.Double + java_array = JavaWrapper._new_java_array(float_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), float_list) + # test array of bools + bool_list = [False, True, True] + java_class = self.sc._gateway.jvm.java.lang.Boolean + java_array = JavaWrapper._new_java_array(bool_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), bool_list) + # test array of Java DenseVectors + v1 = DenseVector([0.0, 1.0]) + v2 = DenseVector([1.0, 0.0]) + vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] + java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector + java_array = JavaWrapper._new_java_array(vec_java_list, java_class) + 
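+        # Converting the Java array back with _java2py should yield the original DenseVectors.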
self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) + # test empty array + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array([], java_class) + self.assertEqual(_java2py(self.sc, java_array), []) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_wrapper import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py new file mode 100644 index 0000000000000..12bf650a28ee1 --- /dev/null +++ b/python/pyspark/testing/mlutils.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np + +from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable +from pyspark.ml.wrapper import _java2py +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import DoubleType +from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase + + +def check_params(test_self, py_stage, check_params_exist=True): + """ + Checks common requirements for Params.params: + - set of params exist in Java and Python and are ordered by names + - param parent has the same UID as the object's UID + - default param value from Java matches value in Python + - optionally check if all params from Java also exist in Python + """ + py_stage_str = "%s %s" % (type(py_stage), py_stage) + if not hasattr(py_stage, "_to_java"): + return + java_stage = py_stage._to_java() + if java_stage is None: + return + test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) + if check_params_exist: + param_names = [p.name for p in py_stage.params] + java_params = list(java_stage.params()) + java_param_names = [jp.name() for jp in java_params] + test_self.assertEqual( + param_names, sorted(java_param_names), + "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" + % (py_stage_str, java_param_names, param_names)) + for p in py_stage.params: + test_self.assertEqual(p.parent, py_stage.uid) + java_param = java_stage.getParam(p.name) + py_has_default = py_stage.hasDefault(p) + java_has_default = java_stage.hasDefault(java_param) + test_self.assertEqual(py_has_default, java_has_default, + "Default value mismatch of param %s for Params %s" + % (p.name, str(py_stage))) + if py_has_default: + if p.name == "seed": + continue # Random seeds between Spark and PySpark are different + java_default = _java2py(test_self.sc, + java_stage.clear(java_param).getOrDefault(java_param)) + py_stage._clear(p) + py_default = 
py_stage.getOrDefault(p) + # equality test for NaN is always False + if isinstance(java_default, float) and np.isnan(java_default): + java_default = "NaN" + py_default = "NaN" if np.isnan(py_default) else "not NaN" + test_self.assertEqual( + java_default, py_default, + "Java default %s != python default %s of param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + + +class SparkSessionTestCase(PySparkTestCase): + @classmethod + def setUpClass(cls): + PySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + PySparkTestCase.tearDownClass() + cls.spark.stop() + + +class MockDataset(DataFrame): + + def __init__(self): + self.index = 0 + + +class HasFake(Params): + + def __init__(self): + super(HasFake, self).__init__() + self.fake = Param(self, "fake", "fake param") + + def getFake(self): + return self.getOrDefault(self.fake) + + +class MockTransformer(Transformer, HasFake): + + def __init__(self): + super(MockTransformer, self).__init__() + self.dataset_index = None + + def _transform(self, dataset): + self.dataset_index = dataset.index + dataset.index += 1 + return dataset + + +class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): + + shift = Param(Params._dummy(), "shift", "The amount by which to shift " + + "data in a DataFrame", + typeConverter=TypeConverters.toFloat) + + def __init__(self, shiftVal=1): + super(MockUnaryTransformer, self).__init__() + self._setDefault(shift=1) + self._set(shift=shiftVal) + + def getShift(self): + return self.getOrDefault(self.shift) + + def setShift(self, shift): + self._set(shift=shift) + + def createTransformFunc(self): + shiftVal = self.getShift() + return lambda x: x + shiftVal + + def outputDataType(self): + return DoubleType() + + def validateInputType(self, inputType): + if inputType != DoubleType(): + raise TypeError("Bad input type: {}. ".format(inputType) + + "Requires Double.") + + +class MockEstimator(Estimator, HasFake): + + def __init__(self): + super(MockEstimator, self).__init__() + self.dataset_index = None + + def _fit(self, dataset): + self.dataset_index = dataset.index + model = MockModel() + self._copyValues(model) + return model + + +class MockModel(MockTransformer, Model, HasFake): + pass From bbbdaa82a4f4fc7a84be6641518264d9bb7bde2b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Nov 2018 09:22:32 +0800 Subject: [PATCH 069/145] [SPARK-26105][PYTHON] Clean unittest2 imports up that were added for Python 2.6 before ## What changes were proposed in this pull request? Currently, some of the PySpark tests still assume the tests could be run in Python 2.6 by importing `unittest2`. For instance: ```python if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest except ImportError: sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') sys.exit(1) else: import unittest ``` While I am here, I removed some unused imports and reordered imports per PEP 8. We officially dropped Python 2.6 support a while ago and started to discuss dropping Python 2. It's better to remove them. ## How was this patch tested? Manual tests, and existing tests via Jenkins. Closes #23077 from HyukjinKwon/SPARK-26105.
Lead-authored-by: hyukjinkwon Co-authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- python/pyspark/ml/tests/test_algorithms.py | 11 +---------- python/pyspark/ml/tests/test_base.py | 10 +--------- python/pyspark/ml/tests/test_evaluation.py | 10 +--------- python/pyspark/ml/tests/test_feature.py | 9 +-------- python/pyspark/ml/tests/test_image.py | 10 +--------- python/pyspark/ml/tests/test_linalg.py | 12 ++---------- python/pyspark/ml/tests/test_param.py | 16 +++++----------- python/pyspark/ml/tests/test_persistence.py | 10 +--------- python/pyspark/ml/tests/test_pipeline.py | 10 +--------- python/pyspark/ml/tests/test_stat.py | 10 +--------- python/pyspark/ml/tests/test_training_summary.py | 9 +-------- python/pyspark/ml/tests/test_tuning.py | 10 +--------- python/pyspark/ml/tests/test_wrapper.py | 10 +--------- python/pyspark/mllib/tests/test_algorithms.py | 13 +------------ python/pyspark/mllib/tests/test_feature.py | 11 +---------- python/pyspark/mllib/tests/test_linalg.py | 11 +---------- python/pyspark/mllib/tests/test_stat.py | 11 +---------- .../mllib/tests/test_streaming_algorithms.py | 11 +---------- python/pyspark/mllib/tests/test_util.py | 15 ++------------- python/pyspark/testing/mllibutils.py | 11 +---------- 20 files changed, 26 insertions(+), 194 deletions(-) diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index 1a72e124962c8..516bb563402e0 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -16,17 +16,8 @@ # from shutil import rmtree -import sys import tempfile - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import numpy as np diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py index 59c45f638dd45..31e3deb53046c 100644 --- a/python/pyspark/ml/tests/test_base.py +++ b/python/pyspark/ml/tests/test_base.py @@ -15,15 +15,7 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.sql.types import DoubleType, IntegerType from pyspark.testing.mlutils import MockDataset, MockEstimator, MockUnaryTransformer, \ diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py index 6c3e5c6734509..5438455a6f756 100644 --- a/python/pyspark/ml/tests/test_evaluation.py +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -15,15 +15,7 @@ # limitations under the License. 
# -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import numpy as np diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 23f66e73b4820..325feaba66957 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -17,14 +17,7 @@ # import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest if sys.version > '3': basestring = str diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py index dcc7a32c9fd70..4c280a4a67894 100644 --- a/python/pyspark/ml/tests/test_image.py +++ b/python/pyspark/ml/tests/test_image.py @@ -14,15 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import py4j diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 76e5386e86125..71cad5d7f5ad7 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -15,17 +15,9 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - +import unittest import array as pyarray + from numpy import arange, array, array_equal, inf, ones, tile, zeros from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \ diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index 1f36d4544ab92..17c1b0bf65dde 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -19,17 +19,7 @@ import inspect import sys import array as pyarray -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -if sys.version > '3': - xrange = range +import unittest import numpy as np @@ -45,6 +35,10 @@ from pyspark.testing.mlutils import check_params, PySparkTestCase, SparkSessionTestCase +if sys.version > '3': + xrange = range + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. 
diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py index b5a2e16df5532..34d687039ab34 100644 --- a/python/pyspark/ml/tests/test_persistence.py +++ b/python/pyspark/ml/tests/test_persistence.py @@ -17,16 +17,8 @@ import json from shutil import rmtree -import sys import tempfile -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml import Transformer from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \ diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py index 31ef02c2e601f..9e3e6c4a75d7a 100644 --- a/python/pyspark/ml/tests/test_pipeline.py +++ b/python/pyspark/ml/tests/test_pipeline.py @@ -14,15 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml.pipeline import Pipeline from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py index bdc4853bc05c2..11aaf2e8083e1 100644 --- a/python/pyspark/ml/tests/test_stat.py +++ b/python/pyspark/ml/tests/test_stat.py @@ -15,15 +15,7 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml.linalg import Vectors from pyspark.ml.stat import ChiSquareTest diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index d5464f7be6372..8575111c84025 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -16,14 +16,7 @@ # import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest if sys.version > '3': basestring = str diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index af00d1de7ab6a..39bb921aaf43d 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -15,16 +15,8 @@ # limitations under the License. 
# -import sys import tempfile -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml import Estimator, Model from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py index 4326d8e060dd7..ae672a00c1dc1 100644 --- a/python/pyspark/ml/tests/test_wrapper.py +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -15,15 +15,7 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import py4j diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index 8a3454144a115..cc3b64b1cb284 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -16,27 +16,16 @@ # import os -import sys import tempfile from shutil import rmtree +import unittest from numpy import array, array_equal - from py4j.protocol import Py4JJavaError -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint -from pyspark.sql.utils import IllegalArgumentException from pyspark.testing.mllibutils import make_serializer, MLlibTestCase diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py index 48ed810fa6fcb..3da841c408558 100644 --- a/python/pyspark/mllib/tests/test_feature.py +++ b/python/pyspark/mllib/tests/test_feature.py @@ -15,20 +15,11 @@ # limitations under the License. 
# -import sys from math import sqrt +import unittest from numpy import array, random, exp, abs, tile -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, Vectors from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.feature import HashingTF, IDF, StandardScaler, ElementwiseProduct, Word2Vec diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index 550e32a9af024..d0ebd9bc3db79 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -17,18 +17,9 @@ import sys import array as pyarray +import unittest from numpy import array, array_equal, zeros, arange, tile, ones, inf -from numpy import sum as array_sum - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest import pyspark.ml.linalg as newlinalg from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py index 5e74087d8fa7b..f23ae291d317a 100644 --- a/python/pyspark/mllib/tests/test_stat.py +++ b/python/pyspark/mllib/tests/test_stat.py @@ -15,20 +15,11 @@ # limitations under the License. # -import sys import array as pyarray +import unittest from numpy import array -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.random import RandomRDDs diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py index ba95855fd4f00..4bc8904acd31c 100644 --- a/python/pyspark/mllib/tests/test_streaming_algorithms.py +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -15,21 +15,12 @@ # limitations under the License. 
# -import sys from time import time, sleep +import unittest from numpy import array, random, exp, dot, all, mean, abs from numpy import sum as array_sum -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark import SparkContext from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py index c924eba80484c..e95716278f122 100644 --- a/python/pyspark/mllib/tests/test_util.py +++ b/python/pyspark/mllib/tests/test_util.py @@ -16,25 +16,14 @@ # import os -import sys import tempfile - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.util import LinearDataGenerator from pyspark.mllib.util import MLUtils -from pyspark.mllib.linalg import SparseVector, DenseVector, SparseMatrix, Vectors +from pyspark.mllib.linalg import SparseVector, DenseVector, Vectors from pyspark.mllib.random import RandomRDDs -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.stat import Statistics from pyspark.testing.mllibutils import MLlibTestCase diff --git a/python/pyspark/testing/mllibutils.py b/python/pyspark/testing/mllibutils.py index 9248182658f84..25f1bba8d37ac 100644 --- a/python/pyspark/testing/mllibutils.py +++ b/python/pyspark/testing/mllibutils.py @@ -15,16 +15,7 @@ # limitations under the License. # -import sys - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark import SparkContext from pyspark.serializers import PickleSerializer From 630e25e35506c02a0b1e202ef82b1b0f69e50966 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 08:06:33 -0600 Subject: [PATCH 070/145] [SPARK-26026][BUILD] Published Scaladoc jars missing from Maven Central ## What changes were proposed in this pull request? This restores scaladoc artifact generation, which got dropped with the Scala 2.12 update. The change looks large, but is almost all due to needing to make the InterfaceStability annotations top-level classes (i.e. `InterfaceStability.Stable` -> `Stable`), unfortunately. A few inner class references had to be qualified too. Lots of scaladoc warnings now reappear. We can choose to disable generation by default and enable for releases, later. ## How was this patch tested? N/A; build runs scaladoc now. Closes #23069 from srowen/SPARK-26026. 
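To make the shape of the rename concrete, here is a minimal Scala sketch of a user-facing class before and after this change. The `ExampleApi` class is hypothetical and only for illustration; the package and annotation names are the ones introduced by the hunks below, assuming the annotations module under `common/tags` is on the classpath.

```scala
// Before: the annotation was a nested type of InterfaceStability.
//   import org.apache.spark.annotation.InterfaceStability
//   @InterfaceStability.Evolving
//   class ExampleApi { ... }

// After: each annotation is its own top-level class.
import org.apache.spark.annotation.Evolving

@Evolving
class ExampleApi {
  // Only the import and the annotation reference change; the class body is untouched.
  def describe(): String = "illustration only"
}
```

The new `Stable`, `Evolving` and `Unstable` annotations keep the same `@Documented`, `@Retention` and `@Target` settings as the old nested variants; they are simply hoisted to the top level. The same flattening is why a few nested references in the Java sources now need an explicit qualifier, e.g. `Type` becomes `Message.Type` in the network protocol classes.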
Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../network/protocol/ChunkFetchFailure.java | 2 +- .../network/protocol/ChunkFetchRequest.java | 2 +- .../network/protocol/ChunkFetchSuccess.java | 2 +- .../spark/network/protocol/OneWayMessage.java | 2 +- .../spark/network/protocol/RpcFailure.java | 2 +- .../spark/network/protocol/RpcRequest.java | 2 +- .../spark/network/protocol/RpcResponse.java | 2 +- .../spark/network/protocol/StreamFailure.java | 2 +- .../spark/network/protocol/StreamRequest.java | 2 +- .../network/protocol/StreamResponse.java | 2 +- .../spark/network/protocol/UploadStream.java | 2 +- .../spark/network/sasl/SaslMessage.java | 3 +- .../network/shuffle/RetryingBlockFetcher.java | 2 +- .../org/apache/spark/annotation/Evolving.java | 30 +++++++++ .../spark/annotation/InterfaceStability.java | 58 ----------------- .../org/apache/spark/annotation/Stable.java | 31 ++++++++++ .../org/apache/spark/annotation/Unstable.java | 30 +++++++++ .../kinesis/KinesisInputDStream.scala | 6 +- .../kinesis/SparkAWSCredentials.scala | 9 ++- .../spark/launcher/AbstractAppHandle.java | 12 ++-- .../org/apache/spark/ml/util/ReadWrite.scala | 10 +-- pom.xml | 8 ++- .../java/org/apache/spark/sql/RowFactory.java | 4 +- .../execution/UnsafeExternalRowSorter.java | 10 +-- .../sql/streaming/GroupStateTimeout.java | 4 +- .../spark/sql/streaming/OutputMode.java | 4 +- .../org/apache/spark/sql/types/DataTypes.java | 4 +- .../spark/sql/types/SQLUserDefinedType.java | 4 +- .../apache/spark/sql/AnalysisException.scala | 5 +- .../scala/org/apache/spark/sql/Encoder.scala | 5 +- .../scala/org/apache/spark/sql/Encoders.scala | 4 +- .../main/scala/org/apache/spark/sql/Row.scala | 6 +- .../spark/sql/types/AbstractDataType.scala | 4 +- .../apache/spark/sql/types/ArrayType.scala | 6 +- .../apache/spark/sql/types/BinaryType.scala | 7 +-- .../apache/spark/sql/types/BooleanType.scala | 7 +-- .../org/apache/spark/sql/types/ByteType.scala | 6 +- .../sql/types/CalendarIntervalType.scala | 6 +- .../org/apache/spark/sql/types/DataType.scala | 6 +- .../org/apache/spark/sql/types/DateType.scala | 6 +- .../org/apache/spark/sql/types/Decimal.scala | 6 +- .../apache/spark/sql/types/DecimalType.scala | 7 +-- .../apache/spark/sql/types/DoubleType.scala | 6 +- .../apache/spark/sql/types/FloatType.scala | 6 +- .../apache/spark/sql/types/IntegerType.scala | 6 +- .../org/apache/spark/sql/types/LongType.scala | 6 +- .../org/apache/spark/sql/types/MapType.scala | 6 +- .../org/apache/spark/sql/types/Metadata.scala | 8 +-- .../org/apache/spark/sql/types/NullType.scala | 7 +-- .../apache/spark/sql/types/ObjectType.scala | 6 +- .../apache/spark/sql/types/ShortType.scala | 6 +- .../apache/spark/sql/types/StringType.scala | 6 +- .../apache/spark/sql/types/StructField.scala | 4 +- .../apache/spark/sql/types/StructType.scala | 8 +-- .../spark/sql/types/TimestampType.scala | 6 +- .../FlatMapGroupsWithStateFunction.java | 4 +- .../function/MapGroupsWithStateFunction.java | 4 +- .../java/org/apache/spark/sql/SaveMode.java | 4 +- .../org/apache/spark/sql/api/java/UDF0.java | 4 +- .../org/apache/spark/sql/api/java/UDF1.java | 4 +- .../org/apache/spark/sql/api/java/UDF10.java | 4 +- .../org/apache/spark/sql/api/java/UDF11.java | 4 +- .../org/apache/spark/sql/api/java/UDF12.java | 4 +- .../org/apache/spark/sql/api/java/UDF13.java | 4 +- .../org/apache/spark/sql/api/java/UDF14.java | 4 +- .../org/apache/spark/sql/api/java/UDF15.java | 4 +- .../org/apache/spark/sql/api/java/UDF16.java | 4 +- .../org/apache/spark/sql/api/java/UDF17.java | 4 +- 
.../org/apache/spark/sql/api/java/UDF18.java | 4 +- .../org/apache/spark/sql/api/java/UDF19.java | 4 +- .../org/apache/spark/sql/api/java/UDF2.java | 4 +- .../org/apache/spark/sql/api/java/UDF20.java | 4 +- .../org/apache/spark/sql/api/java/UDF21.java | 4 +- .../org/apache/spark/sql/api/java/UDF22.java | 4 +- .../org/apache/spark/sql/api/java/UDF3.java | 4 +- .../org/apache/spark/sql/api/java/UDF4.java | 4 +- .../org/apache/spark/sql/api/java/UDF5.java | 4 +- .../org/apache/spark/sql/api/java/UDF6.java | 4 +- .../org/apache/spark/sql/api/java/UDF7.java | 4 +- .../org/apache/spark/sql/api/java/UDF8.java | 4 +- .../org/apache/spark/sql/api/java/UDF9.java | 4 +- ...emaColumnConvertNotSupportedException.java | 4 +- .../spark/sql/expressions/javalang/typed.java | 4 +- .../sources/v2/BatchReadSupportProvider.java | 4 +- .../sources/v2/BatchWriteSupportProvider.java | 4 +- .../v2/ContinuousReadSupportProvider.java | 4 +- .../sql/sources/v2/DataSourceOptions.java | 4 +- .../spark/sql/sources/v2/DataSourceV2.java | 4 +- .../v2/MicroBatchReadSupportProvider.java | 4 +- .../sql/sources/v2/SessionConfigSupport.java | 4 +- .../v2/StreamingWriteSupportProvider.java | 4 +- .../sources/v2/reader/BatchReadSupport.java | 4 +- .../sql/sources/v2/reader/InputPartition.java | 4 +- .../sources/v2/reader/PartitionReader.java | 4 +- .../v2/reader/PartitionReaderFactory.java | 4 +- .../sql/sources/v2/reader/ReadSupport.java | 4 +- .../sql/sources/v2/reader/ScanConfig.java | 4 +- .../sources/v2/reader/ScanConfigBuilder.java | 4 +- .../sql/sources/v2/reader/Statistics.java | 4 +- .../v2/reader/SupportsPushDownFilters.java | 4 +- .../SupportsPushDownRequiredColumns.java | 4 +- .../v2/reader/SupportsReportPartitioning.java | 4 +- .../v2/reader/SupportsReportStatistics.java | 4 +- .../partitioning/ClusteredDistribution.java | 4 +- .../v2/reader/partitioning/Distribution.java | 4 +- .../v2/reader/partitioning/Partitioning.java | 4 +- .../streaming/ContinuousPartitionReader.java | 4 +- .../ContinuousPartitionReaderFactory.java | 4 +- .../streaming/ContinuousReadSupport.java | 4 +- .../streaming/MicroBatchReadSupport.java | 4 +- .../sources/v2/reader/streaming/Offset.java | 4 +- .../v2/reader/streaming/PartitionOffset.java | 4 +- .../sources/v2/writer/BatchWriteSupport.java | 4 +- .../sql/sources/v2/writer/DataWriter.java | 4 +- .../sources/v2/writer/DataWriterFactory.java | 4 +- .../v2/writer/WriterCommitMessage.java | 4 +- .../streaming/StreamingDataWriterFactory.java | 4 +- .../streaming/StreamingWriteSupport.java | 4 +- .../apache/spark/sql/streaming/Trigger.java | 4 +- .../sql/vectorized/ArrowColumnVector.java | 4 +- .../spark/sql/vectorized/ColumnVector.java | 4 +- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarBatch.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 4 +- .../scala/org/apache/spark/sql/Column.scala | 8 +-- .../spark/sql/DataFrameNaFunctions.scala | 5 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../spark/sql/DataFrameStatFunctions.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../scala/org/apache/spark/sql/Dataset.scala | 62 +++++++++---------- .../org/apache/spark/sql/DatasetHolder.scala | 4 +- .../spark/sql/ExperimentalMethods.scala | 4 +- .../org/apache/spark/sql/ForeachWriter.scala | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 16 ++--- .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../org/apache/spark/sql/RuntimeConfig.scala | 5 +- .../org/apache/spark/sql/SQLContext.scala | 34 +++++----- 
.../org/apache/spark/sql/SQLImplicits.scala | 4 +- .../org/apache/spark/sql/SparkSession.scala | 48 +++++++------- .../spark/sql/SparkSessionExtensions.scala | 4 +- .../apache/spark/sql/UDFRegistration.scala | 4 +- .../apache/spark/sql/catalog/Catalog.scala | 16 ++--- .../apache/spark/sql/catalog/interface.scala | 10 +-- .../sql/execution/streaming/Triggers.scala | 4 +- .../continuous/ContinuousTrigger.scala | 6 +- .../spark/sql/expressions/Aggregator.scala | 6 +- .../sql/expressions/UserDefinedFunction.scala | 4 +- .../apache/spark/sql/expressions/Window.scala | 6 +- .../spark/sql/expressions/WindowSpec.scala | 4 +- .../sql/expressions/scalalang/typed.scala | 4 +- .../apache/spark/sql/expressions/udaf.scala | 6 +- .../org/apache/spark/sql/functions.scala | 4 +- .../internal/BaseSessionStateBuilder.scala | 4 +- .../spark/sql/internal/SessionState.scala | 6 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 9 ++- .../scala/org/apache/spark/sql/package.scala | 4 +- .../apache/spark/sql/sources/filters.scala | 34 +++++----- .../apache/spark/sql/sources/interfaces.scala | 26 ++++---- .../sql/streaming/DataStreamReader.scala | 4 +- .../sql/streaming/DataStreamWriter.scala | 8 +-- .../spark/sql/streaming/GroupState.scala | 5 +- .../spark/sql/streaming/ProcessingTime.scala | 6 +- .../spark/sql/streaming/StreamingQuery.scala | 4 +- .../streaming/StreamingQueryException.scala | 4 +- .../streaming/StreamingQueryListener.scala | 14 ++--- .../sql/streaming/StreamingQueryManager.scala | 4 +- .../sql/streaming/StreamingQueryStatus.scala | 4 +- .../apache/spark/sql/streaming/progress.scala | 10 +-- .../sql/util/QueryExecutionListener.scala | 6 +- .../hive/service/cli/thrift/TColumn.java | 2 +- .../hive/service/cli/thrift/TColumnValue.java | 2 +- .../service/cli/thrift/TGetInfoValue.java | 2 +- .../hive/service/cli/thrift/TTypeEntry.java | 2 +- .../cli/thrift/TTypeQualifierValue.java | 2 +- .../apache/hive/service/AbstractService.java | 8 +-- .../apache/hive/service/FilterService.java | 2 +- .../sql/hive/HiveSessionStateBuilder.scala | 4 +- 177 files changed, 590 insertions(+), 563 deletions(-) create mode 100644 common/tags/src/main/java/org/apache/spark/annotation/Evolving.java delete mode 100644 common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java create mode 100644 common/tags/src/main/java/org/apache/spark/annotation/Stable.java create mode 100644 common/tags/src/main/java/org/apache/spark/annotation/Unstable.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 7b28a9a969486..a7afbfa8621c8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -33,7 +33,7 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { } @Override - public Type type() { return Type.ChunkFetchFailure; } + public Message.Type type() { return Type.ChunkFetchFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 26d063feb5fe3..fe54fcc50dc86 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ 
b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -32,7 +32,7 @@ public ChunkFetchRequest(StreamChunkId streamChunkId) { } @Override - public Type type() { return Type.ChunkFetchRequest; } + public Message.Type type() { return Type.ChunkFetchRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 94c2ac9b20e43..d5c9a9b3202fb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -39,7 +39,7 @@ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { } @Override - public Type type() { return Type.ChunkFetchSuccess; } + public Message.Type type() { return Type.ChunkFetchSuccess; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java index f7ffb1bd49bb6..1632fb9e03687 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -34,7 +34,7 @@ public OneWayMessage(ManagedBuffer body) { } @Override - public Type type() { return Type.OneWayMessage; } + public Message.Type type() { return Type.OneWayMessage; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index a76624ef5dc96..61061903de23f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -31,7 +31,7 @@ public RpcFailure(long requestId, String errorString) { } @Override - public Type type() { return Type.RpcFailure; } + public Message.Type type() { return Type.RpcFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 2b30920f0598d..cc1bb95d2d566 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -38,7 +38,7 @@ public RpcRequest(long requestId, ManagedBuffer message) { } @Override - public Type type() { return Type.RpcRequest; } + public Message.Type type() { return Type.RpcRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index d73014ecd8506..c03291e9c0b23 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -33,7 +33,7 @@ public RpcResponse(long requestId, ManagedBuffer message) { } @Override - public Type type() { return Type.RpcResponse; } + public 
Message.Type type() { return Type.RpcResponse; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index 258ef81c6783d..68fcfa7748611 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -33,7 +33,7 @@ public StreamFailure(String streamId, String error) { } @Override - public Type type() { return Type.StreamFailure; } + public Message.Type type() { return Type.StreamFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index dc183c043ed9a..1b135af752bd8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -34,7 +34,7 @@ public StreamRequest(String streamId) { } @Override - public Type type() { return Type.StreamRequest; } + public Message.Type type() { return Type.StreamRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 50b811604b84b..568108c4fe5e8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -40,7 +40,7 @@ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { } @Override - public Type type() { return Type.StreamResponse; } + public Message.Type type() { return Type.StreamResponse; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java index fa1d26e76b852..7d21151e01074 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -52,7 +52,7 @@ private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { } @Override - public Type type() { return Type.UploadStream; } + public Message.Type type() { return Type.UploadStream; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 7331c2b481fb1..1b03300d948e2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -23,6 +23,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.protocol.AbstractMessage; +import org.apache.spark.network.protocol.Message; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged @@ -46,7 +47,7 @@ class SaslMessage extends 
AbstractMessage { } @Override - public Type type() { return Type.User; } + public Message.Type type() { return Type.User; } @Override public int encodedLength() { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index f309dda8afca6..6bf3da94030d4 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -101,7 +101,7 @@ void createAndStart(String[] blockIds, BlockFetchingListener listener) public RetryingBlockFetcher( TransportConf conf, - BlockFetchStarter fetchStarter, + RetryingBlockFetcher.BlockFetchStarter fetchStarter, String[] blockIds, BlockFetchingListener listener) { this.fetchStarter = fetchStarter; diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java new file mode 100644 index 0000000000000..87e8948f204ff --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. + * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Evolving {} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java deleted file mode 100644 index 02bcec737e80e..0000000000000 --- a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.annotation; - -import java.lang.annotation.*; - -/** - * Annotation to inform users of how much to rely on a particular package, - * class or method not changing over time. - */ -public class InterfaceStability { - - /** - * Stable APIs that retain source and binary compatibility within a major release. - * These interfaces can change from one major release to another major release - * (e.g. from 1.0 to 2.0). - */ - @Documented - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) - public @interface Stable {}; - - /** - * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. - * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). - */ - @Documented - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) - public @interface Evolving {}; - - /** - * Unstable APIs, with no guarantee on stability. - * Classes that are unannotated are considered Unstable. - */ - @Documented - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) - public @interface Unstable {}; -} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Stable.java b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java new file mode 100644 index 0000000000000..b198bfbe91e10 --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * Stable APIs that retain source and binary compatibility within a major release. + * These interfaces can change from one major release to another major release + * (e.g. from 1.0 to 2.0). 
+ */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Stable {} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java new file mode 100644 index 0000000000000..88ee72125b23f --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * Unstable APIs, with no guarantee on stability. + * Classes that are unannotated are considered Unstable. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Unstable {} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 1ffec01df9f00..d4a428f45c110 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} @@ -84,14 +84,14 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( } } -@InterfaceStability.Evolving +@Evolving object KinesisInputDStream { /** * Builder for [[KinesisInputDStream]] instances. 
* * @since 2.2.0 */ - @InterfaceStability.Evolving + @Evolving class Builder { // Required params private var streamingContext: Option[StreamingContext] = None diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala index 9facfe8ff2b0f..dcb60b21d9851 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -14,13 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.streaming.kinesis -import scala.collection.JavaConverters._ +package org.apache.spark.streaming.kinesis import com.amazonaws.auth._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging /** @@ -84,14 +83,14 @@ private[kinesis] final case class STSCredentials( } } -@InterfaceStability.Evolving +@Evolving object SparkAWSCredentials { /** * Builder for [[SparkAWSCredentials]] instances. * * @since 2.2.0 */ - @InterfaceStability.Evolving + @Evolving class Builder { private var basicCreds: Option[BasicCredentials] = None private var stsCreds: Option[STSCredentials] = None diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index 9cbebdaeb33d3..0999cbd216871 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -31,8 +31,8 @@ abstract class AbstractAppHandle implements SparkAppHandle { private final LauncherServer server; private LauncherServer.ServerConnection connection; - private List listeners; - private AtomicReference state; + private List listeners; + private AtomicReference state; private volatile String appId; private volatile boolean disposed; @@ -42,7 +42,7 @@ protected AbstractAppHandle(LauncherServer server) { } @Override - public synchronized void addListener(Listener l) { + public synchronized void addListener(SparkAppHandle.Listener l) { if (listeners == null) { listeners = new CopyOnWriteArrayList<>(); } @@ -50,7 +50,7 @@ public synchronized void addListener(Listener l) { } @Override - public State getState() { + public SparkAppHandle.State getState() { return state.get(); } @@ -120,11 +120,11 @@ synchronized void dispose() { } } - void setState(State s) { + void setState(SparkAppHandle.State s) { setState(s, false); } - void setState(State s, boolean force) { + void setState(SparkAppHandle.State s, boolean force) { if (force) { state.set(s); fireEvent(false); diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d985f8ca1ecc7..fbc7be25a5640 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.annotation.{DeveloperApi, Since, Unstable} import org.apache.spark.internal.Logging import org.apache.spark.ml._ 
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -84,7 +84,7 @@ private[util] sealed trait BaseReadWrite { * * @since 2.4.0 */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") trait MLWriterFormat { /** @@ -108,7 +108,7 @@ trait MLWriterFormat { * * @since 2.4.0 */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") trait MLFormatRegister extends MLWriterFormat { /** @@ -208,7 +208,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { /** * A ML Writer which delegates based on the requested format. */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { private var source: String = "internal" @@ -291,7 +291,7 @@ trait MLWritable { * Trait for classes that provide `GeneralMLWriter`. */ @Since("2.4.0") -@InterfaceStability.Unstable +@Unstable trait GeneralMLWritable extends MLWritable { /** * Returns an `MLWriter` instance for this ML instance. diff --git a/pom.xml b/pom.xml index 59e3d0fa772b4..fcec295eee128 100644 --- a/pom.xml +++ b/pom.xml @@ -2016,7 +2016,6 @@ net.alchim31.maven scala-maven-plugin - 3.4.4 @@ -2037,6 +2036,13 @@ testCompile + + attach-scaladocs + verify + + doc-jar + + ${scala.version} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java index 2ce1fdcbf56ae..0258e66ffb6e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java @@ -17,7 +17,7 @@ package org.apache.spark.sql; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; import org.apache.spark.sql.catalyst.expressions.GenericRow; /** @@ -25,7 +25,7 @@ * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public class RowFactory { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 1b2f5eee5ccdd..5395e4035e680 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -50,7 +50,7 @@ public final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final PrefixComputer prefixComputer; + private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; public abstract static class PrefixComputer { @@ -74,7 +74,7 @@ public static UnsafeExternalRowSorter createWithRecordComparator( StructType schema, Supplier recordComparatorSupplier, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, @@ -85,7 +85,7 @@ public static UnsafeExternalRowSorter create( StructType schema, Ordering ordering, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { Supplier recordComparatorSupplier = @@ -98,9 +98,9 @@ private UnsafeExternalRowSorter( StructType schema, Supplier recordComparatorSupplier, PrefixComparator prefixComparator, - PrefixComputer 
prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, - boolean canUseRadixSort) throws IOException { + boolean canUseRadixSort) { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index 5f1032d1229da..5f6a46f2b8e89 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.streaming; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.plans.logical.*; /** @@ -29,7 +29,7 @@ * @since 2.2.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving public class GroupStateTimeout { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 470c128ee6c3d..a3d72a1f5d49f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** @@ -26,7 +26,7 @@ * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving public class OutputMode { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 0f8570fe470bd..d786374f69e20 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -19,7 +19,7 @@ import java.util.*; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * To get/create specific data type, users should use singleton objects and factory methods @@ -27,7 +27,7 @@ * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public class DataTypes { /** * Gets the StringType object. 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index 1290614a3207d..a54398324fc66 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -20,7 +20,7 @@ import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * ::DeveloperApi:: @@ -31,7 +31,7 @@ @DeveloperApi @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) -@InterfaceStability.Evolving +@Evolving public @interface SQLUserDefinedType { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 50ee6cd4085ea..f5c87677ab9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - /** * Thrown when a query fails to analyze, usually because the query itself is invalid. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 7b02317b8538f..9853a4fcc2f9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql import scala.annotation.implicitNotFound import scala.reflect.ClassTag -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.types._ - /** * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. @@ -67,7 +66,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving @implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + "store ${T} instances in a Dataset. 
Primitive types (Int, String, etc) and Product types (case " + "classes) are supported by importing spark.implicits._ Support for serializing other types " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 8a30c81912fe9..42b865c027205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Modifier import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving object Encoders { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 180c2d130074e..e12bf9616e2de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.util.hashing.MurmurHash3 -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object Row { /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: @@ -124,7 +124,7 @@ object Row { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait Row extends Serializable { /** Number of elements in the Row. */ def size: Int = length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index c43cc748655e8..5367ce2af8e9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.Expression /** @@ -134,7 +134,7 @@ object AtomicType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. 
In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 58c75b5dc7a35..7465569868f07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -21,7 +21,7 @@ import scala.math.Ordering import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.ArrayData /** @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object ArrayType extends AbstractDataType { /** * Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. @@ -60,7 +60,7 @@ object ArrayType extends AbstractDataType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { /** No-arg constructor for kryo. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index 032d6b54aeb79..cc8b3e6e399a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.TypeUtils - /** * The data type representing `Array[Byte]` values. * Please use the singleton `DataTypes.BinaryType`. */ -@InterfaceStability.Stable +@Stable class BinaryType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. @@ -55,5 +54,5 @@ class BinaryType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index 63f354d2243cf..5e3de71caa37e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability - +import org.apache.spark.annotation.Stable /** * The data type representing `Boolean` values. Please use the singleton `DataTypes.BooleanType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class BooleanType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. 
@@ -48,5 +47,5 @@ class BooleanType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 5854c3f5ba116..9d400eefc0f8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Byte` values. Please use the singleton `DataTypes.ByteType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ByteType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. @@ -52,5 +52,5 @@ class ByteType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index 2342036a57460..8e297874a0d62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing calendar time intervals. 
The calendar time interval is stored @@ -29,7 +29,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable class CalendarIntervalType private() extends DataType { override def defaultSize: Int = 16 @@ -40,5 +40,5 @@ class CalendarIntervalType private() extends DataType { /** * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable case object CalendarIntervalType extends CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 33fc4b9480126..c58f7a2397374 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -26,7 +26,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -38,7 +38,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: @@ -111,7 +111,7 @@ abstract class DataType extends AbstractDataType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 9e70dd486a125..7491014b22dab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * A date type, supporting "0001-01-01" through "9999-12-31". @@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class DateType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DateType$" in byte code. 
@@ -53,5 +53,5 @@ class DateType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb202045..a3a844670e0c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Unstable import org.apache.spark.sql.AnalysisException /** @@ -31,7 +31,7 @@ import org.apache.spark.sql.AnalysisException * - If decimalVal is set, it represents the whole decimal value * - Otherwise, the decimal value is longVal / (10 ** _scale) */ -@InterfaceStability.Unstable +@Unstable final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal._ @@ -407,7 +407,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } -@InterfaceStability.Unstable +@Unstable object Decimal { val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 15004e4b9667d..25eddaf06a780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -21,11 +21,10 @@ import java.util.Locale import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} - /** * The data type representing `java.math.BigDecimal` values. 
* A Decimal that must have fixed precision (the maximum number of digits) and scale (the number @@ -39,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class DecimalType(precision: Int, scale: Int) extends FractionalType { if (scale > precision) { @@ -110,7 +109,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object DecimalType extends AbstractDataType { import scala.math.min diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index a5c79ff01ca06..afd3353397019 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.util.Utils /** @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class DoubleType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. @@ -54,5 +54,5 @@ class DoubleType private() extends FractionalType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 352147ec936c9..6d98987304081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.util.Utils /** @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class FloatType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. @@ -55,5 +55,5 @@ class FloatType private() extends FractionalType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index a85e3729188d9..0755202d20df1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Int` values. 
Please use the singleton `DataTypes.IntegerType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class IntegerType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. @@ -51,5 +51,5 @@ class IntegerType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 0997028fc1057..3c49c721fdc88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Long` values. Please use the singleton `DataTypes.LongType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class LongType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "LongType$" in byte code. @@ -51,5 +51,5 @@ class LongType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 594e155268bf6..29b9ffc0c3549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type for Maps. Keys in a map are not allowed to have `null` values. @@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. 
*/ -@InterfaceStability.Stable +@Stable case class MapType( keyType: DataType, valueType: DataType, @@ -78,7 +78,7 @@ case class MapType( /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 7c15dc0de4b6b..4979aced145c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** @@ -37,7 +37,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { @@ -117,7 +117,7 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object Metadata { private[this] val _empty = new Metadata(Map.empty) @@ -228,7 +228,7 @@ object Metadata { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class MetadataBuilder { private val map: mutable.Map[String, Any] = mutable.Map.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index 494225b47a270..14097a5280d50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.InterfaceStability - +import org.apache.spark.annotation.Stable /** * The data type representing `NULL` values. Please use the singleton `DataTypes.NullType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class NullType private() extends DataType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "NullType$" in byte code. @@ -38,5 +37,5 @@ class NullType private() extends DataType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 203e85e1c99bd..6756b209f432e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.types import scala.language.existentials -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving -@InterfaceStability.Evolving +@Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException( @@ -38,7 +38,7 @@ object ObjectType extends AbstractDataType { /** * Represents a JVM object that is passing through Spark SQL expression evaluation. 
*/ -@InterfaceStability.Evolving +@Evolving case class ObjectType(cls: Class[_]) extends DataType { override def defaultSize: Int = 4096 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index ee655c338b59f..9b5ddfef1ccf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Short` values. Please use the singleton `DataTypes.ShortType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ShortType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. @@ -51,5 +51,5 @@ class ShortType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 59b124cda7d14..8ce1cd078e312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.unsafe.types.UTF8String /** @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. 
@@ -48,6 +48,6 @@ class StringType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object StringType extends StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 35f9970a0aaec..6f6b561d67d49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} /** @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdenti * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class StructField( name: String, dataType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 06289b1483203..3bef75d5bdb6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,10 +24,10 @@ import scala.util.control.NonFatal import org.json4s.JsonDSL._ import org.apache.spark.SparkException -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.util.Utils /** @@ -95,7 +95,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ @@ -422,7 +422,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object StructType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = new StructType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index fdb91e0499920..a20f155418f8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `java.sql.Timestamp` values. @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. 
@@ -50,5 +50,5 @@ class TimestampType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object TimestampType extends TimestampType diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index 802949c0ddb60..d4e1d89491f43 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -20,8 +20,8 @@ import java.io.Serializable; import java.util.Iterator; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.streaming.GroupState; /** @@ -33,7 +33,7 @@ * @since 2.1.1 */ @Experimental -@InterfaceStability.Evolving +@Evolving public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable { Iterator<R> call(K key, Iterator<V> values, GroupState<S> state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 353e9886a8a57..f0abfde843cc5 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -20,8 +20,8 @@ import java.io.Serializable; import java.util.Iterator; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.streaming.GroupState; /** @@ -32,7 +32,7 @@ * @since 2.1.1 */ @Experimental -@InterfaceStability.Evolving +@Evolving public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable { R call(K key, Iterator<V> values, GroupState<S> state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java index 1c3c9794fb6bb..9cc073f53a3eb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -16,14 +16,14 @@ */ package org.apache.spark.sql; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public enum SaveMode { /** * Append mode means that when saving a DataFrame to a data source, if data/table already exists, diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java index 4eeb7be3f5abb..631d6eb1cfb03 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 0 arguments.
*/ -@InterfaceStability.Stable +@Stable public interface UDF0<R> extends Serializable { R call() throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java index 1460daf27dc20..a5d01406edd8c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 1 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF1<T1, R> extends Serializable { R call(T1 t1) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java index 7c4f1e4897084..effe99e30b2a5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 10 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF10<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, R> extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java index 26a05106aebd6..e70b18b84b08f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 11 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF11<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, R> extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java index 8ef7a99042025..339feb34135e1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 12 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF12<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, R> extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java index 5c3b2ec1222e2..d346e5c908c6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 13 arguments.
*/ -@InterfaceStability.Stable +@Stable public interface UDF13 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java index 97e744d843466..d27f9f5270f4b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 14 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF14 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java index 7ddbf914fc11a..b99b57a91d465 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 15 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF15 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java index 0ae5dc7195ad6..7899fc4b7ad65 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 16 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF16 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java index 03543a556c614..40a7e95724fc2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 17 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF17 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java index 46740d3443916..47935a935891c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 18 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF18 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java index 33fefd8ecaf1d..578b796ff03a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 19 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF19 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java index 9822f19217d76..2f856aa3cf630 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 2 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF2 extends Serializable { R call(T1 t1, T2 t2) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java index 8c5e90182da1c..aa8a9fa897040 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 20 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF20 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java index e3b09f5167cff..0fe52bce2eca2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 21 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF21 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java index dc6cfa9097bab..69fd8ca422833 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 22 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF22 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java index 7c264b69ba195..84ffd655672a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 3 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF3 extends Serializable { R call(T1 t1, T2 t2, T3 t3) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java index 58df38fc3c911..dd2dc285c226d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 4 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF4 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java index 4146f96e2eed5..795cc21c3f76e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 5 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF5 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java index 25d39654c1095..a954684c3c9a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 6 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF6 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java index ce63b6a91adbb..03761f2c9ebbf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 7 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF7 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java index 0e00209ef6b9f..8cd3583b2cbf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 8 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF8 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java index 077981bb3e3ee..78a7097791963 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 9 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF9 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java index 82a1169cbe7ae..7d1fbe64fc960 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution.datasources; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Unstable; /** * Exception thrown when the parquet reader find column type mismatches. */ -@InterfaceStability.Unstable +@Unstable public class SchemaColumnConvertNotSupportedException extends RuntimeException { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index ec9c107b1c119..5a72f0c6a2555 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions.javalang; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.TypedColumn; import org.apache.spark.sql.execution.aggregate.TypedAverage; @@ -35,7 +35,7 @@ * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving public class typed { // Note: make sure to keep in sync with typed.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java index f403dc619e86c..2a4933d75e8d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link BatchReadSupport} instances when end users run * {@code SparkSession.read.format(...).option(...).load()}. 
*/ -@InterfaceStability.Evolving +@Evolving public interface BatchReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java index bd10c3353bf12..df439e2c02fe3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java @@ -19,7 +19,7 @@ import java.util.Optional; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; import org.apache.spark.sql.types.StructType; @@ -31,7 +31,7 @@ * This interface is used to create {@link BatchWriteSupport} instances when end users run * {@code Dataset.write.format(...).option(...).save()}. */ -@InterfaceStability.Evolving +@Evolving public interface BatchWriteSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java index 824c290518acf..b4f2eb34a1560 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link ContinuousReadSupport} instances when end users run * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index 83df3be747085..1c5e3a0cd31e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -26,7 +26,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An immutable string-to-string map in which keys are case-insensitive. 
This is used to represent @@ -73,7 +73,7 @@ * * */ -@InterfaceStability.Evolving +@Evolving public class DataSourceOptions { private final Map keyLowerCasedMap; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6e31e84bf6c72..eae7a45d1d446 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. @@ -30,5 +30,5 @@ * If Spark fails to execute any methods in the implementations of this interface (by throwing an * exception), the read action will fail and no Spark job will be submitted. */ -@InterfaceStability.Evolving +@Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java index 61c08e7fa89df..c4d9ef88f607e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link MicroBatchReadSupport} instances when end users run * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. */ -@InterfaceStability.Evolving +@Evolving public interface MicroBatchReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index bbe430e299261..c00abd9b685b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this * session. 
*/ -@InterfaceStability.Evolving +@Evolving public interface SessionConfigSupport extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java index f9ca85d8089b4..8ac9c51750865 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.sql.streaming.OutputMode; @@ -30,7 +30,7 @@ * This interface is used to create {@link StreamingWriteSupport} instances when end users run * {@code Dataset.writeStream.format(...).option(...).start()}. */ -@InterfaceStability.Evolving +@Evolving public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java index 452ee86675b42..518a8b03a2c6e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface that defines how to load the data from data source for batch processing. @@ -29,7 +29,7 @@ * {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and reader * factory to scan data from the data source with a Spark job. */ -@InterfaceStability.Evolving +@Evolving public interface BatchReadSupport extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 95c30de907e44..5f5248084bad6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A serializable representation of an input partition returned by @@ -32,7 +32,7 @@ * the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader} * doesn't need to be. 
*/ -@InterfaceStability.Evolving +@Evolving public interface InputPartition extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java index 04ff8d0a19fc3..2945925959538 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java @@ -20,7 +20,7 @@ import java.io.Closeable; import java.io.IOException; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or @@ -32,7 +32,7 @@ * data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)} * returns true). */ -@InterfaceStability.Evolving +@Evolving public interface PartitionReader extends Closeable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java index f35de9310eee3..97f4a473953fc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -30,7 +30,7 @@ * {@link PartitionReader} (by throwing an exception), corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ -@InterfaceStability.Evolving +@Evolving public interface PartitionReaderFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java index a58ddb288f1ed..b1f610a82e8a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -27,7 +27,7 @@ * If Spark fails to execute any methods in the implementations of this interface (by throwing an * exception), the read action will fail and no Spark job will be submitted. 
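A minimal in-memory sketch of the partition/reader split described above, in Scala. `RangePartition` and `RangeReaderFactory` are hypothetical names; only the factory and the partition are serialized to executors, the reader is created there.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}

// One partition of a hypothetical source: a half-open integer range.
// InputPartition must be serializable (case classes are); the reader is not.
case class RangePartition(start: Int, end: Int) extends InputPartition

class RangeReaderFactory extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    val p = partition.asInstanceOf[RangePartition]
    new PartitionReader[InternalRow] {
      private var current = p.start - 1
      override def next(): Boolean = { current += 1; current < p.end }
      override def get(): InternalRow = InternalRow(current)
      override def close(): Unit = ()
    }
  }
}
```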
*/ -@InterfaceStability.Evolving +@Evolving public interface ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java index 7462ce2820585..a69872a527746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -31,7 +31,7 @@ * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. */ -@InterfaceStability.Evolving +@Evolving public interface ScanConfig { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java index 4c0eedfddfe22..4922962f70655 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface for building the {@link ScanConfig}. Implementations can mixin those * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in * the returned {@link ScanConfig}. */ -@InterfaceStability.Evolving +@Evolving public interface ScanConfigBuilder { ScanConfig build(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index 44799c7d49137..14776f37fed46 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -19,13 +19,13 @@ import java.util.OptionalLong; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface to represent statistics for a data source, which is returned by * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. */ -@InterfaceStability.Evolving +@Evolving public interface Statistics { OptionalLong sizeInBytes(); OptionalLong numRows(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 5e7985f645a06..3a89baa1b44c2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.Filter; /** * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to * push down filters to the data source and reduce the size of the data to be read. 
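A sketch of filter pushdown on the builder, in Scala. `MyScanConfigBuilder` is hypothetical, only `EqualTo` filters are treated as supported, and the sketch assumes `readSchema()` is the only method a `ScanConfig` has to provide at this revision.

```scala
import org.apache.spark.sql.sources.{EqualTo, Filter}
import org.apache.spark.sql.sources.v2.reader.{ScanConfig, SupportsPushDownFilters}
import org.apache.spark.sql.types.StructType

// Hypothetical builder: keeps the pushdown result and bakes it into the ScanConfig.
class MyScanConfigBuilder(schema: StructType) extends SupportsPushDownFilters {
  private var pushed: Array[Filter] = Array.empty

  // Return the filters the source cannot evaluate; Spark re-applies those after the scan.
  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (supported, unsupported) = filters.partition(_.isInstanceOf[EqualTo])
    pushed = supported
    unsupported
  }

  override def pushedFilters(): Array[Filter] = pushed

  // Assumption: readSchema() is ScanConfig's only abstract method here.
  override def build(): ScanConfig = new ScanConfig {
    override def readSchema(): StructType = schema
  }
}
```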
*/ -@InterfaceStability.Evolving +@Evolving public interface SupportsPushDownFilters extends ScanConfigBuilder { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index edb164937d6ef..1934763224881 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -25,7 +25,7 @@ * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index db62cd4515362..0335c7775c2af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** @@ -27,7 +27,7 @@ * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsReportPartitioning extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 1831488ba096f..917372cdd25b3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to @@ -27,7 +27,7 @@ * data source. Implementations that return more accurate statistics based on pushed operators will * not improve query performance until the planner can push operators before getting stats. 
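Since `Statistics` above only exposes two `OptionalLong` values, an implementation can be a small case class; a `SupportsReportStatistics` source would return something like this (hypothetical name) from `estimateStatistics(ScanConfig)`.

```scala
import java.util.OptionalLong

import org.apache.spark.sql.sources.v2.reader.Statistics

// Both values are optional; report only what the source actually knows.
case class KnownStatistics(bytes: Long, rows: Long) extends Statistics {
  override def sizeInBytes(): OptionalLong = OptionalLong.of(bytes)
  override def numRows(): OptionalLong = OptionalLong.of(rows)
}
```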
*/ -@InterfaceStability.Evolving +@Evolving public interface SupportsReportStatistics extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 6764d4b7665c7..1cdc02f5736b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** @@ -25,7 +25,7 @@ * share the same values for the {@link #clusteredColumns} will be produced by the same * {@link PartitionReader}. */ -@InterfaceStability.Evolving +@Evolving public class ClusteredDistribution implements Distribution { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index 364a3f553923c..02b0e68974919 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** @@ -37,5 +37,5 @@ *
 *   <li>{@link ClusteredDistribution}</li>
  • * */ -@InterfaceStability.Evolving +@Evolving public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index fb0b6f1df43bb..c9a00262c1287 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; @@ -28,7 +28,7 @@ * like a snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. */ -@InterfaceStability.Evolving +@Evolving public interface Partitioning { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java index 9101c8a44d34e..c7f6fce6e81af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java @@ -17,13 +17,13 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * A variation on {@link PartitionReader} for use with continuous streaming processing. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousPartitionReader extends PartitionReader { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java index 2d9f1ca1686a1..41195befe5e57 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; @@ -28,7 +28,7 @@ * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for * continuous streaming processing. 
*/ -@InterfaceStability.Evolving +@Evolving public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { @Override ContinuousPartitionReader createReader(InputPartition partition); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java index 9a3ad2eb8a801..2b784ac0e9f35 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.ScanConfig; @@ -36,7 +36,7 @@ * {@link #stop()} will be called when the streaming execution is completed. Note that a single * query may have multiple executions due to restart or failure recovery. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java index edb0db11bff2c..f56066c639388 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.*; @@ -33,7 +33,7 @@ * will be called when the streaming execution is completed. Note that a single query may have * multiple executions due to restart or failure recovery. */ -@InterfaceStability.Evolving +@Evolving public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index 6cf27734867cb..6104175d2c9e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An abstract representation of progress through a {@link MicroBatchReadSupport} or @@ -30,7 +30,7 @@ * maintain compatibility with DataSource V1 APIs. This extension will be removed once we * get rid of V1 completely. 
*/ -@InterfaceStability.Evolving +@Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index 383e73db6762b..2c97d924a0629 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will @@ -27,6 +27,6 @@ * * These offsets must be serializable. */ -@InterfaceStability.Evolving +@Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java index 0ec9e05d6a02b..efe1ac4f78db1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.writer; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface that defines how to write the data to data source for batch processing. @@ -37,7 +37,7 @@ * * Please refer to the documentation of commit/abort methods for detailed specifications. */ -@InterfaceStability.Evolving +@Evolving public interface BatchWriteSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 5fb067966ee67..d142ee523ef9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -19,7 +19,7 @@ import java.io.IOException; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is @@ -55,7 +55,7 @@ * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}. */ -@InterfaceStability.Evolving +@Evolving public interface DataWriter { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 19a36dd232456..65105f46b82d5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -20,7 +20,7 @@ import java.io.Serializable; import org.apache.spark.TaskContext; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; /** @@ -31,7 +31,7 @@ * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. 
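A sketch of the write-side contract in Scala: the factory is serialized to executors, each task gets its own writer, and the commit message travels back to the driver. The names are hypothetical, and the `createWriter(int, long)` signature is taken from the Javadoc reference above.

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

// Sent back to the driver, so it must be serializable.
case class RowsWritten(partitionId: Int, count: Long) extends WriterCommitMessage

class BufferingWriterFactory extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] =
    new DataWriter[InternalRow] {
      private val buffer = ArrayBuffer.empty[InternalRow]
      override def write(record: InternalRow): Unit = buffer += record.copy()
      override def commit(): WriterCommitMessage = RowsWritten(partitionId, buffer.size.toLong)
      override def abort(): Unit = buffer.clear()  // drop anything buffered for this task
    }
}
```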
*/ -@InterfaceStability.Evolving +@Evolving public interface DataWriterFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 123335c414e9f..9216e34399092 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,8 +19,8 @@ import java.io.Serializable; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side @@ -30,5 +30,5 @@ * This is an empty interface, data sources should define their own message class and use it when * generating messages at executor side and handling the messages at driver side. */ -@InterfaceStability.Evolving +@Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java index a4da24fc5ae68..7d3d21cb2b637 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -20,7 +20,7 @@ import java.io.Serializable; import org.apache.spark.TaskContext; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.writer.DataWriter; @@ -33,7 +33,7 @@ * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ -@InterfaceStability.Evolving +@Evolving public interface StreamingDataWriterFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java index 3fdfac5e1c84a..84cfbf2dda483 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.writer.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; @@ -27,7 +27,7 @@ * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. 
*/ -@InterfaceStability.Evolving +@Evolving public interface StreamingWriteSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 5371a23230c98..fd6f7be2abc5a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -19,9 +19,9 @@ import java.util.concurrent.TimeUnit; +import org.apache.spark.annotation.Evolving; import scala.concurrent.duration.Duration; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; @@ -30,7 +30,7 @@ * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving public class Trigger { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 5f58b031f6aef..906e9bc26ef53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -22,7 +22,7 @@ import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; @@ -31,7 +31,7 @@ * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. */ -@InterfaceStability.Evolving +@Evolving public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index ad99b450a4809..14caaeaedbe2b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; @@ -47,7 +47,7 @@ * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage * footprint is negligible. 
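A spark-shell style sketch of how the `Trigger` factory methods above are used; the rate source and console sink are only stand-ins for a real pipeline.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().appName("trigger-demo").master("local[*]").getOrCreate()

val rates = spark.readStream.format("rate").load()

// Swap the trigger to change how the same query is executed:
// Trigger.ProcessingTime("10 seconds"), Trigger.Once(), or Trigger.Continuous("1 second").
val query = rates.writeStream
  .format("console")
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .start()
```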
*/ -@InterfaceStability.Evolving +@Evolving public abstract class ColumnVector implements AutoCloseable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 72a192d089b9f..dd2bd789c26d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -25,7 +25,7 @@ /** * Array abstraction in {@link ColumnVector}. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index d206c1df42abb..07546a54013ec 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -18,7 +18,7 @@ import java.util.*; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; @@ -27,7 +27,7 @@ * batch so that Spark can access the data row by row. Instance of it is meant to be reused during * the entire data loading process. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarBatch { private int numRows; private final ColumnVector[] columns; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index f2f2279590023..4b9d3c5f59915 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.types.*; @@ -26,7 +26,7 @@ /** * Row abstraction in {@link ColumnVector}. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarRow extends InternalRow { // The data for this row. // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. 
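Two equivalent ways to consume a `ColumnarBatch`, assuming a batch whose column 0 is an INT column: row by row through the reused row view, or directly against the `ColumnVector`.

```scala
import org.apache.spark.sql.vectorized.ColumnarBatch

// Row-wise: rowIterator() returns a view over the columns, no per-row copies.
def sumRowWise(batch: ColumnarBatch): Long = {
  var sum = 0L
  val it = batch.rowIterator()
  while (it.hasNext) {
    val row = it.next()
    if (!row.isNullAt(0)) sum += row.getInt(0)
  }
  sum
}

// Column-wise: index straight into the vector.
def sumColumnWise(batch: ColumnarBatch): Long = {
  var sum = 0L
  val col = batch.column(0)
  var i = 0
  while (i < batch.numRows()) {
    if (!col.isNullAt(i)) sum += col.getInt(i)
    i += 1
  }
  sum
}
```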
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 49d2a34080b13..5a408b29f9337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} @@ -60,7 +60,7 @@ private[sql] object Column { * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable class TypedColumn[-T, U]( expr: Expression, private[sql] val encoder: ExpressionEncoder[U]) @@ -130,7 +130,7 @@ class TypedColumn[-T, U]( * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class Column(val expr: Expression) extends Logging { def this(name: String) = this(name match { @@ -1227,7 +1227,7 @@ class Column(val expr: Expression) extends Logging { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ColumnName(name: String) extends Column(name) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 5288907b7d7ff..53e9f810d7c85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -22,18 +22,17 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ - /** * Functionality for working with missing data in `DataFrame`s. 
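Typical `df.na` usage, as a spark-shell style sketch with made-up data.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("na-demo").master("local[*]").getOrCreate()
import spark.implicits._

val people = Seq(("alice", Some(30)), ("bob", None)).toDF("name", "age")

people.na.drop().show()                 // drop rows that contain any null
people.na.fill(Map("age" -> 0)).show()  // replace nulls in "age" with 0
```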
* * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable final class DataFrameNaFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index df18623e42a02..52df13d39caa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.Partition -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 7c12432d33c33..b2f6a6ba83108 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -21,7 +21,7 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col @@ -33,7 +33,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable final class DataFrameStatFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 1b4998f94b25d..29d479f542115 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,7 +21,7 @@ import java.util.{Locale, Properties, UUID} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f98eaa3d4eb90..f5caaf3f7fc87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,19 +21,17 @@ import java.io.CharArrayWriter import scala.collection.JavaConverters._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils import org.apache.spark.TaskContext -import org.apache.spark.annotation.{DeveloperApi, Experimental, 
InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ @@ -166,10 +164,10 @@ private[sql] object Dataset { * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable class Dataset[T] private[sql]( @transient val sparkSession: SparkSession, - @DeveloperApi @InterfaceStability.Unstable @transient val queryExecution: QueryExecution, + @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, encoder: Encoder[T]) extends Serializable { @@ -426,7 +424,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** @@ -544,7 +542,7 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def isStreaming: Boolean = logicalPlan.isStreaming /** @@ -557,7 +555,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) /** @@ -570,7 +568,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true) /** @@ -583,7 +581,7 @@ class Dataset[T] private[sql]( * @since 2.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) /** @@ -596,7 +594,7 @@ class Dataset[T] private[sql]( * @since 2.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint( eager = eager, reliableCheckpoint = false @@ -671,7 +669,7 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { @@ -1066,7 +1064,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, // etc. 
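What the typed join buys over `join()`: both sides stay as objects. A spark-shell style sketch with hypothetical case classes.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

case class User(id: Long, name: String)
case class Order(userId: Long, amount: Double)

val spark = SparkSession.builder().appName("joinWith-demo").master("local[*]").getOrCreate()
import spark.implicits._

val users  = Seq(User(1, "alice"), User(2, "bob")).toDS()
val orders = Seq(Order(1, 9.5), Order(1, 20.0)).toDS()

// Dataset[(User, Order)] instead of a flattened DataFrame.
val pairs: Dataset[(User, Order)] =
  users.joinWith(orders, users("id") === orders("userId"), "inner")
pairs.show()
```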
@@ -1142,7 +1140,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } @@ -1384,7 +1382,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) @@ -1418,7 +1416,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] @@ -1430,7 +1428,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1445,7 +1443,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3, U4]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1461,7 +1459,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3, U4, U5]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1632,7 +1630,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def reduce(func: (T, T) => T): T = withNewRDDExecutionId { rdd.reduce(func) } @@ -1647,7 +1645,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** @@ -1659,7 +1657,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1681,7 +1679,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) @@ -2483,7 +2481,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def filter(func: T => Boolean): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2497,7 +2495,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def filter(func: FilterFunction[T]): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2511,7 +2509,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { MapElements[T, U](func, logicalPlan) } @@ -2525,7 +2523,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) @@ -2540,7 +2538,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapPartitions[U : Encoder](func: 
Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, @@ -2557,7 +2555,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) @@ -2588,7 +2586,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) @@ -2602,7 +2600,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) @@ -3064,7 +3062,7 @@ class Dataset[T] private[sql]( * @group basic * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def writeStream: DataStreamWriter[T] = { if (!isStreaming) { logicalPlan.failAnalysis( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 08aa1bbe78fae..1c4ffefb897ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * A container for a [[Dataset]], used for implicit conversions in Scala. @@ -30,7 +30,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index bd8dd6ea3fe0f..302d38cde1430 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable class ExperimentalMethods private[sql]() { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 52b8c839643e7..5c0fe798b1044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * The abstract class for writing custom logic to process data generated by a query. 
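A minimal `ForeachWriter`, assuming a streaming `Dataset[String]`; it would be attached with `writeStream.foreach(new ConsoleRowWriter).start()`.

```scala
import org.apache.spark.sql.ForeachWriter

// Prints each value on the executor that processes it.
class ConsoleRowWriter extends ForeachWriter[String] {
  def open(partitionId: Long, epochId: Long): Boolean = true  // accept every partition/epoch
  def process(value: String): Unit = println(value)
  def close(errorOrNull: Throwable): Unit = ()                // nothing to release
}
```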
@@ -104,7 +104,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 555bcdffb6ee4..7a47242f69381 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} @@ -37,7 +37,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], @@ -237,7 +237,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) @@ -272,7 +272,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { @@ -309,7 +309,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S, U]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], @@ -340,7 +340,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S, U]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], @@ -371,7 +371,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout)( @@ -413,7 +413,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d4e75b5ebd405..e85636d82a62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable 
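A sketch of `mapGroupsWithState` keeping one `Long` of state per key; run here over a small batch Dataset so it is self-contained, and the same shape applies to a streaming query.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.GroupState

val spark = SparkSession.builder().appName("state-demo").master("local[*]").getOrCreate()
import spark.implicits._

val clicks = Seq("home", "cart", "home", "home").toDS()

// One Long of state per key; emit the running total for each group.
val counts = clicks
  .groupByKey(identity)
  .mapGroupsWithState[Long, (String, Long)] { (page, visits, state: GroupState[Long]) =>
    val total = state.getOption.getOrElse(0L) + visits.size
    state.update(total)
    (page, total)
  }

counts.show()
```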
import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} @@ -45,7 +45,7 @@ import org.apache.spark.sql.types.{NumericType, StructType} * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class RelationalGroupedDataset protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 3c39579149fff..5a554eff02e3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} import org.apache.spark.sql.internal.SQLConf - /** * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. * @@ -29,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9982b60fefe60..43f34e6ff4b85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,7 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry @@ -54,7 +54,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @groupname Ungrouped Support functions for language integrated queries * @since 1.0.0 */ -@InterfaceStability.Stable +@Stable class SQLContext private[sql](val sparkSession: SparkSession) extends Logging with Serializable { @@ -86,7 +86,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * that listen for execution metrics. 
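RuntimeConfig is reached through `SparkSession.conf`; a short sketch:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("conf-demo").master("local[*]").getOrCreate()

// Only runtime-settable SQL configs can be changed here.
spark.conf.set("spark.sql.shuffle.partitions", "8")
val partitions = spark.conf.get("spark.sql.shuffle.partitions")  // "8"
```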
*/ @Experimental - @InterfaceStability.Evolving + @Evolving def listenerManager: ExecutionListenerManager = sparkSession.listenerManager /** @@ -158,7 +158,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) */ @Experimental @transient - @InterfaceStability.Unstable + @Unstable def experimental: ExperimentalMethods = sparkSession.experimental /** @@ -244,7 +244,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = self } @@ -258,7 +258,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { sparkSession.createDataFrame(rdd) } @@ -271,7 +271,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { sparkSession.createDataFrame(data) } @@ -319,7 +319,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -363,7 +363,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -401,7 +401,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -428,7 +428,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -443,7 +443,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.6.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rows, schema) } @@ -507,7 +507,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def readStream: DataStreamReader = sparkSession.readStream @@ -631,7 +631,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(end: Long): DataFrame = sparkSession.range(end).toDF() /** @@ -643,7 +643,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long): DataFrame = sparkSession.range(start, end).toDF() /** @@ -655,7 +655,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long): DataFrame = { sparkSession.range(start, end, step).toDF() } @@ -670,7 +670,7 @@ class SQLContext private[sql](val sparkSession: 
SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { sparkSession.range(start, end, step, numPartitions).toDF() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 05db292bd41b1..d329af0145c2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -21,7 +21,7 @@ import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder * * @since 1.6.0 */ -@InterfaceStability.Evolving +@Evolving abstract class SQLImplicits extends LowPrioritySQLImplicits { protected def _sqlContext: SQLContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c0727e844a1ca..725db97df4ed1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -25,7 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -73,7 +73,7 @@ import org.apache.spark.util.{CallSite, Utils} * @param parentSessionState If supplied, inherit all session state (i.e. temporary * views, SQL config, UDFs etc) from parent. 
*/ -@InterfaceStability.Stable +@Stable class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], @@ -124,7 +124,7 @@ class SparkSession private( * * @since 2.2.0 */ - @InterfaceStability.Unstable + @Unstable @transient lazy val sharedState: SharedState = { existingSharedState.getOrElse(new SharedState(sparkContext)) @@ -145,7 +145,7 @@ class SparkSession private( * * @since 2.2.0 */ - @InterfaceStability.Unstable + @Unstable @transient lazy val sessionState: SessionState = { parentSessionState @@ -186,7 +186,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** @@ -197,7 +197,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Unstable + @Unstable def experimental: ExperimentalMethods = sessionState.experimentalMethods /** @@ -231,7 +231,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Unstable + @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager /** @@ -289,7 +289,7 @@ class SparkSession private( * @return 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def emptyDataset[T: Encoder]: Dataset[T] = { val encoder = implicitly[Encoder[T]] new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder) @@ -302,7 +302,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SparkSession.setActiveSession(this) val encoder = Encoders.product[A] @@ -316,7 +316,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] @@ -356,7 +356,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema, needsConversion = true) } @@ -370,7 +370,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD.rdd, schema) } @@ -384,7 +384,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } @@ -474,7 +474,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes @@ -493,7 +493,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { Dataset[T](self, ExternalRDD(data, self)) } @@ -515,7 +515,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -528,7 +528,7 @@ class 
SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** @@ -539,7 +539,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) } @@ -552,7 +552,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) } @@ -566,7 +566,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG) } @@ -672,7 +672,7 @@ class SparkSession private( * * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def readStream: DataStreamReader = new DataStreamReader(self) /** @@ -706,7 +706,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext } @@ -775,13 +775,13 @@ class SparkSession private( } -@InterfaceStability.Stable +@Stable object SparkSession extends Logging { /** * Builder for [[SparkSession]]. */ - @InterfaceStability.Stable + @Stable class Builder extends Logging { private[this] val options = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index a4864344b2d25..5ed76789786bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.mutable -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -66,7 +66,7 @@ import org.apache.spark.sql.catalyst.rules.Rule */ @DeveloperApi @Experimental -@InterfaceStability.Unstable +@Unstable class SparkSessionExtensions { type RuleBuilder = SparkSession => Rule[LogicalPlan] type CheckRuleBuilder = SparkSession => LogicalPlan => Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 84da097be53c1..5a3f556c9c074 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -22,7 +22,7 @@ import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ @@ -44,7 +44,7 @@ import 
org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index ab81725def3f4..44668610d8052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental, Stable} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel @@ -29,7 +29,7 @@ import org.apache.spark.storage.StorageLevel * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable abstract class Catalog { /** @@ -233,7 +233,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable(tableName: String, path: String): DataFrame /** @@ -261,7 +261,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable(tableName: String, path: String, source: String): DataFrame /** @@ -292,7 +292,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -330,7 +330,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -366,7 +366,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -406,7 +406,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala index c0c5ebc2ba2d6..cb270875228ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog import javax.annotation.Nullable -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.DefinedByConstructorParams @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams * @param locationUri path (in the form of a uri) to data files. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Database( val name: String, @Nullable val description: String, @@ -61,7 +61,7 @@ class Database( * @param isTemporary whether the table is a temporary table. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Table( val name: String, @Nullable val database: String, @@ -93,7 +93,7 @@ class Table( * @param isBucket whether the column is a bucket column. 
* @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Column( val name: String, @Nullable val description: String, @@ -126,7 +126,7 @@ class Column( * @param isTemporary whether the function is a temporary function or not. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Function( val name: String, @Nullable val database: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 19e3e55cb2829..4c0db3cb42a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.streaming.Trigger /** @@ -25,5 +25,5 @@ import org.apache.spark.sql.streaming.Trigger * the query. */ @Experimental -@InterfaceStability.Evolving +@Evolving case object OneTimeTrigger extends Trigger diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala index 90e1766c4d9f1..caffcc3c4c1a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala @@ -23,15 +23,15 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} +import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.streaming.Trigger import org.apache.spark.unsafe.types.CalendarInterval /** * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at * the specified interval. 
*/ -@InterfaceStability.Evolving +@Evolving case class ContinuousTrigger(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 1e076207bc607..6b4def35e1955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} +import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.sql.{Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -51,7 +51,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index eb956c4b3e888..58a942afe28c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.ScalaUDF @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index d50031bb20621..3d8d931af218e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable object Window { /** @@ -234,5 +234,5 @@ object Window { * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class Window private() // So we can see Window in JavaDoc. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index b7f3000880aca..58227f075f2c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.{AnalysisException, Column} import org.apache.spark.sql.catalyst.expressions._ @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class WindowSpec private[sql]( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 3e637d594caf3..1cb579c4faa76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions.scalalang -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate._ @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.aggregate._ * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving // scalastyle:off object typed { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 4976b875fa298..4e8cb3a6ddd66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaUDAF @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable abstract class UserDefinedAggregateFunction extends Serializable { /** @@ -159,7 +159,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b2a6e22cbfc86..1cf2a30c0c8bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import scala.util.control.NonFatal -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} @@ -68,7 +68,7 @@ import org.apache.spark.util.Utils * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable // scalastyle:off object functions { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index f67cc32c15dd2..ac07e1f6bb4f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog @@ -50,7 +50,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * and `catalog` fields. Note that the state is cloned when `build` is called, and not before. */ @Experimental -@InterfaceStability.Unstable +@Unstable abstract class BaseSessionStateBuilder( val session: SparkSession, val parentState: Option[SessionState] = None) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index accbea41b9603..b34db581ca2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ @@ -124,7 +124,7 @@ private[sql] object SessionState { * Concrete implementation of a [[BaseSessionStateBuilder]]. */ @Experimental -@InterfaceStability.Unstable +@Unstable class SessionStateBuilder( session: SparkSession, parentState: Option[SessionState] = None) @@ -135,7 +135,7 @@ class SessionStateBuilder( /** * Session shared [[FunctionResourceLoader]]. 
*/ -@InterfaceStability.Unstable +@Unstable class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index f76c1fae562c6..230b43022b02b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -21,8 +21,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.annotation.{DeveloperApi, Evolving, Since} import org.apache.spark.sql.types._ /** @@ -34,7 +33,7 @@ import org.apache.spark.sql.types._ * send a null value to the database. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) /** @@ -57,7 +56,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. @@ -197,7 +196,7 @@ abstract class JdbcDialect extends Serializable { * sure to register your dialects first. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving object JdbcDialects { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 354660e9d5943..61875931d226e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Unstable} import org.apache.spark.sql.execution.SparkStrategy /** @@ -40,7 +40,7 @@ package object sql { * [[org.apache.spark.sql.sources]] */ @DeveloperApi - @InterfaceStability.Unstable + @Unstable type Strategy = SparkStrategy type DataFrame = Dataset[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index bdd8c4da6bd30..3f941cc6e1072 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class Filter { /** * List of columns that are referenced by this filter. 
@@ -48,7 +48,7 @@ abstract class Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -60,7 +60,7 @@ case class EqualTo(attribute: String, value: Any) extends Filter { * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -71,7 +71,7 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -82,7 +82,7 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -93,7 +93,7 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -104,7 +104,7 @@ case class LessThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -114,7 +114,7 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class In(attribute: String, values: Array[Any]) extends Filter { override def hashCode(): Int = { var h = attribute.hashCode @@ -141,7 +141,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -151,7 +151,7 @@ case class IsNull(attribute: String) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -161,7 +161,7 @@ case class IsNotNull(attribute: String) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -171,7 +171,7 @@ case class And(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -181,7 +181,7 @@ case class Or(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references } @@ -192,7 +192,7 @@ case class Not(child: Filter) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringStartsWith(attribute: String, value: String) extends 
Filter { override def references: Array[String] = Array(attribute) } @@ -203,7 +203,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -214,7 +214,7 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6057a795c8bf5..6ad054c9f6403 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation._ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable trait DataSourceRegister { /** @@ -65,7 +65,7 @@ trait DataSourceRegister { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait RelationProvider { /** * Returns a new base relation with the given parameters. @@ -96,7 +96,7 @@ trait RelationProvider { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. 
@@ -117,7 +117,7 @@ trait SchemaRelationProvider { * @since 2.0.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait StreamSourceProvider { /** @@ -148,7 +148,7 @@ trait StreamSourceProvider { * @since 2.0.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait StreamSinkProvider { def createSink( sqlContext: SQLContext, @@ -160,7 +160,7 @@ trait StreamSinkProvider { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait CreatableRelationProvider { /** * Saves a DataFrame to a destination (using data source-specific parameters) @@ -192,7 +192,7 @@ trait CreatableRelationProvider { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class BaseRelation { def sqlContext: SQLContext def schema: StructType @@ -242,7 +242,7 @@ abstract class BaseRelation { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait TableScan { def buildScan(): RDD[Row] } @@ -253,7 +253,7 @@ trait TableScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -271,7 +271,7 @@ trait PrunedScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } @@ -293,7 +293,7 @@ trait PrunedFilteredScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait InsertableRelation { def insert(data: DataFrame, overwrite: Boolean): Unit } @@ -309,7 +309,7 @@ trait InsertableRelation { * @since 1.3.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index bf6021e692382..e4250145a1ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** * Specifies the input data source format. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b36a8f3f6f15b..5733258a6b310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -39,7 +39,7 @@ import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() @@ -365,7 +365,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.4.0 */ - @InterfaceStability.Evolving + @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { this.source = "foreachBatch" if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") @@ -386,7 +386,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.4.0 */ - @InterfaceStability.Evolving + @Evolving def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index e9510c903acae..ab68eba81b843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.KeyValueGroupedDataset +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** @@ -192,7 +191,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * @since 2.2.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving trait GroupState[S] extends LogicalGroupState[S] { /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index a033575d3d38f..236bd55ee6212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -23,7 +23,7 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.unsafe.types.CalendarInterval /** @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.CalendarInterval * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") @@ -59,7 +59,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index f2dfbe42260d7..47ddc88e964e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.sql.SparkSession /** @@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession * All these methods are thread-safe. * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving trait StreamingQuery { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 03aeb14de502a..646d6888b2a16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Exception that stopped a [[StreamingQuery]]. 
Use `cause` get the actual exception @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * @param endOffset Ending offset in json of the range of data in exception occurred * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryException private[sql]( private val queryDebugString: String, val message: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 6aa82b89ede81..916d6a0365965 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.scheduler.SparkListenerEvent /** @@ -28,7 +28,7 @@ import org.apache.spark.scheduler.SparkListenerEvent * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving abstract class StreamingQueryListener { import StreamingQueryListener._ @@ -67,14 +67,14 @@ abstract class StreamingQueryListener { * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving object StreamingQueryListener { /** * Base type of [[StreamingQueryListener]] events * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving trait Event extends SparkListenerEvent /** @@ -84,7 +84,7 @@ object StreamingQueryListener { * @param name User-specified name of the query, null if not specified. * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryStartedEvent private[sql]( val id: UUID, val runId: UUID, @@ -95,7 +95,7 @@ object StreamingQueryListener { * @param progress The query progress updates. * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event /** @@ -107,7 +107,7 @@ object StreamingQueryListener { * with an exception. Otherwise, it will be `None`. 
* @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryTerminatedEvent private[sql]( val id: UUID, val runId: UUID, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index cd52d991d55c9..d9ea8dc9d4ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker @@ -42,7 +42,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { private[sql] val stateStoreCoordinator = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index a0c9bcc8929eb..9dc62b7aac891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -22,7 +22,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Reports information about the instantaneous status of a streaming query. @@ -34,7 +34,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryStatus protected[sql]( val message: String, val isDataAvailable: Boolean, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index f2173aa1e59c2..3cd6700efef5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -29,12 +29,12 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. */ -@InterfaceStability.Evolving +@Evolving class StateOperatorProgress private[sql]( val numRowsTotal: Long, val numRowsUpdated: Long, @@ -94,7 +94,7 @@ class StateOperatorProgress private[sql]( * @param sources detailed statistics on data being read from each of the streaming sources. * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryProgress private[sql]( val id: UUID, val runId: UUID, @@ -165,7 +165,7 @@ class StreamingQueryProgress private[sql]( * Spark. 
* @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class SourceProgress protected[sql]( val description: String, val startOffset: String, @@ -209,7 +209,7 @@ class SourceProgress protected[sql]( * @param description Description of the source corresponding to this status. * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class SinkProgress protected[sql]( val description: String) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 1310fdfa1356b..77ae047705de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.util import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental} import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} import org.apache.spark.sql.SparkSession @@ -36,7 +36,7 @@ import org.apache.spark.util.{ListenerBus, Utils} * multiple different threads. */ @Experimental -@InterfaceStability.Evolving +@Evolving trait QueryExecutionListener { /** @@ -73,7 +73,7 @@ trait QueryExecutionListener { * Manager for [[QueryExecutionListener]]. See `org.apache.spark.sql.SQLContext.listenerManager`. */ @Experimental -@InterfaceStability.Evolving +@Evolving // The `session` is used to indicate which session carries this listener manager, and we only // catch SQL executions which are launched by the same session. // The `loadExtensions` flag is used to indicate whether we should load the pre-defined, diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java index bfe50c7810f73..fc2171dc99e4c 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java @@ -148,7 +148,7 @@ public TColumn() { super(); } - public TColumn(_Fields setField, Object value) { + public TColumn(TColumn._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java index 44da2cdd089d6..8504c6d608d42 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java @@ -142,7 +142,7 @@ public TColumnValue() { super(); } - public TColumnValue(_Fields setField, Object value) { + public TColumnValue(TColumnValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java index 4fe59b1c51462..fe2a211c46309 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java @@ -136,7 +136,7 @@ public TGetInfoValue() { super(); } - 
public TGetInfoValue(_Fields setField, Object value) { + public TGetInfoValue(TGetInfoValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java index af7c0b4f15d95..d0d70c1279572 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java @@ -136,7 +136,7 @@ public TTypeEntry() { super(); } - public TTypeEntry(_Fields setField, Object value) { + public TTypeEntry(TTypeEntry._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java index 8c40687a0aab7..a3e3829372276 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java @@ -112,7 +112,7 @@ public TTypeQualifierValue() { super(); } - public TTypeQualifierValue(_Fields setField, Object value) { + public TTypeQualifierValue(TTypeQualifierValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java index 9dd0efc03968d..7e557aeccf5b0 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java @@ -36,7 +36,7 @@ public abstract class AbstractService implements Service { /** * Service state: initially {@link STATE#NOTINITED}. */ - private STATE state = STATE.NOTINITED; + private Service.STATE state = STATE.NOTINITED; /** * Service name. 
@@ -70,7 +70,7 @@ public AbstractService(String name) { } @Override - public synchronized STATE getServiceState() { + public synchronized Service.STATE getServiceState() { return state; } @@ -159,7 +159,7 @@ public long getStartTime() { * if the service state is different from * the desired state */ - private void ensureCurrentState(STATE currentState) { + private void ensureCurrentState(Service.STATE currentState) { ServiceOperations.ensureCurrentState(state, currentState); } @@ -173,7 +173,7 @@ private void ensureCurrentState(STATE currentState) { * @param newState * new service state */ - private void changeState(STATE newState) { + private void changeState(Service.STATE newState) { state = newState; // notify listeners for (ServiceStateChangeListener l : listeners) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java index 5a508745414a7..15551da4785f6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java @@ -71,7 +71,7 @@ public HiveConf getHiveConf() { } @Override - public STATE getServiceState() { + public Service.STATE getServiceState() { return service.getServiceState(); } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2882672f327c4..4f3914740ec20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo * Builder that produces a Hive-aware `SessionState`. */ @Experimental -@InterfaceStability.Unstable +@Unstable class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { From ce2cdc36e29742dda22200963cfd3f9876170455 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 08:07:20 -0600 Subject: [PATCH 071/145] [SPARK-26043][CORE] Make SparkHadoopUtil private to Spark ## What changes were proposed in this pull request? Make SparkHadoopUtil private to Spark ## How was this patch tested? Existing tests. Closes #23066 from srowen/SPARK-26043. 
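For downstream code this means `SparkHadoopUtil` can no longer be referenced outside of Spark itself. For illustration only (not part of this patch), a minimal sketch of the kind of call site that stops compiling and the standard public alternative, assuming the caller only needed a Hadoop `Configuration`:

```scala
import org.apache.spark.sql.SparkSession

// Assumes an already-configured session (e.g. spark-shell or an app launched via spark-submit).
val spark = SparkSession.builder().getOrCreate()

// Previously reachable through the DeveloperApi; no longer compiles once the class is private[spark]:
// val hadoopConf = org.apache.spark.deploy.SparkHadoopUtil.get.conf

// Public alternative: the Hadoop Configuration owned by the SparkContext.
val hadoopConf = spark.sparkContext.hadoopConfiguration
```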
Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../apache/spark/deploy/SparkHadoopUtil.scala | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 5979151345415..217e5145f1c56 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -18,10 +18,9 @@ package org.apache.spark.deploy import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} -import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date, Locale} +import java.util.{Arrays, Date, Locale} import scala.collection.JavaConverters._ import scala.collection.immutable.Map @@ -38,17 +37,13 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Contains util methods to interact with Hadoop from Spark. */ -@DeveloperApi -class SparkHadoopUtil extends Logging { +private[spark] class SparkHadoopUtil extends Logging { private val sparkConf = new SparkConf(false).loadFromSystemProperties(true) val conf: Configuration = newConfiguration(sparkConf) UserGroupInformation.setConfiguration(conf) @@ -274,11 +269,10 @@ class SparkHadoopUtil extends Logging { name.startsWith(prefix) && !name.endsWith(exclusionSuffix) } }) - Arrays.sort(fileStatuses, new Comparator[FileStatus] { - override def compare(o1: FileStatus, o2: FileStatus): Int = { + Arrays.sort(fileStatuses, + (o1: FileStatus, o2: FileStatus) => { Longs.compare(o1.getModificationTime, o2.getModificationTime) - } - }) + }) fileStatuses } catch { case NonFatal(e) => @@ -388,7 +382,7 @@ class SparkHadoopUtil extends Logging { } -object SparkHadoopUtil { +private[spark] object SparkHadoopUtil { private lazy val instance = new SparkHadoopUtil From b58b1fdf906d9609321824fc0bb892b986763b3e Mon Sep 17 00:00:00 2001 From: "Liu,Linhong" Date: Mon, 19 Nov 2018 22:09:44 +0800 Subject: [PATCH 072/145] [SPARK-26068][CORE] ChunkedByteBufferInputStream should handle empty chunks correctly ## What changes were proposed in this pull request? Empty chunk in ChunkedByteBuffer will truncate the ChunkedByteBufferInputStream. The detail reason is described in: https://issues.apache.org/jira/browse/SPARK-26068 ## How was this patch tested? Modified current UT to cover this case. Closes #23040 from LinhongLiu/fix-empty-chunked-byte-buffer. 
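For illustration (not part of the commit message), a minimal sketch of the failure mode, mirroring the updated unit test below; since `ChunkedByteBuffer` is `private[spark]`, this assumes it runs from test code inside the `org.apache.spark` namespace:

```scala
import java.nio.ByteBuffer
import org.apache.spark.util.io.ChunkedByteBuffer

val empty  = ByteBuffer.wrap(Array.empty[Byte])
val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte))
val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte))

// An empty chunk sits between two non-empty chunks.
val buffer = new ChunkedByteBuffer(Array(empty, bytes1, empty, bytes2))

// The stream should yield all 256 + 128 bytes; before this fix, read() treated the
// empty chunk as end-of-stream and the count came up short.
val in = buffer.toInputStream(dispose = false)
val bytesRead = Iterator.continually(in.read()).takeWhile(_ != -1).size
assert(bytesRead == buffer.size)
```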
Lead-authored-by: Liu,Linhong Co-authored-by: Xianjin YE Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/util/io/ChunkedByteBuffer.scala | 3 ++- .../scala/org/apache/spark/io/ChunkedByteBufferSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 870830fff4c3e..128d6ff8cd746 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -222,7 +222,8 @@ private[spark] class ChunkedByteBufferInputStream( dispose: Boolean) extends InputStream { - private[this] var chunks = chunkedByteBuffer.getChunks().iterator + // Filter out empty chunks since `read()` assumes all chunks are non-empty. + private[this] var chunks = chunkedByteBuffer.getChunks().filter(_.hasRemaining).iterator private[this] var currentChunk: ByteBuffer = { if (chunks.hasNext) { chunks.next() diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index ff117b1c21cb1..083c5e696b753 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -90,7 +90,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte)) val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte)) - val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, bytes2)) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, empty, bytes2)) assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit()) val inputStream = chunkedByteBuffer.toInputStream(dispose = false) From 48ea64bf5bd4201c6a7adca67e20b75d23c223f6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 19 Nov 2018 22:18:20 +0800 Subject: [PATCH 073/145] [SPARK-26112][SQL] Update since versions of new built-in functions. ## What changes were proposed in this pull request? The following 5 functions were removed from branch-2.4: - map_entries - map_filter - transform_values - transform_keys - map_zip_with We should update the since version to 3.0.0. ## How was this patch tested? Existing tests. Closes #23082 from ueshin/issues/SPARK-26112/since. 
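For reference (illustration only, not part of this patch), the five functions are all map-related collection/higher-order functions; a short sketch of what a few of them compute, with the expected values quoted from the `ExpressionDescription` usage examples in the diff below:

```scala
// Assumes a SparkSession named `spark` (e.g. in spark-shell).

spark.sql("SELECT map_entries(map(1, 'a', 2, 'b'))").show(false)
// documented result: [{"key":1,"value":"a"},{"key":2,"value":"b"}]

spark.sql("SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v)").show(false)
// documented result: {1:0,3:-1}

spark.sql("SELECT transform_keys(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v)").show(false)
// documented result: {2:1,4:2,6:3}
```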
Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- R/pkg/R/functions.R | 2 +- python/pyspark/sql/functions.py | 2 +- .../catalyst/expressions/collectionOperations.scala | 2 +- .../catalyst/expressions/higherOrderFunctions.scala | 12 ++++++------ .../main/scala/org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9abb7fc1fadb4..f72645a257796 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3370,7 +3370,7 @@ setMethod("flatten", #' #' @rdname column_collection_functions #' @aliases map_entries map_entries,Column-method -#' @note map_entries since 2.4.0 +#' @note map_entries since 3.0.0 setMethod("map_entries", signature(x = "Column"), function(x) { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e86749cc15c35..286ef219a69e9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2576,7 +2576,7 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) -@since(2.4) +@since(3.0) def map_entries(col): """ Collection function: Returns an unordered array of all entries in the given map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b24d7486f3454..3c260954a72a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -350,7 +350,7 @@ case class MapValues(child: Expression) > SELECT _FUNC_(map(1, 'a', 2, 'b')); [{"key":1,"value":"a"},{"key":2,"value":"b"}] """, - since = "2.4.0") + since = "3.0.0") case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index b07d9466ba0d1..0b698f9290711 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -264,13 +264,13 @@ case class ArrayTransform( * Filters entries in a map using the provided function. 
*/ @ExpressionDescription( -usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", -examples = """ + usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", + examples = """ Examples: > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); {1:0,3:-1} """, -since = "2.4.0") + since = "3.0.0") case class MapFilter( argument: Expression, function: Expression) @@ -504,7 +504,7 @@ case class ArrayAggregate( > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); {2:1,4:2,6:3} """, - since = "2.4.0") + since = "3.0.0") case class TransformKeys( argument: Expression, function: Expression) @@ -554,7 +554,7 @@ case class TransformKeys( > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); {1:2,2:4,3:6} """, - since = "2.4.0") + since = "3.0.0") case class TransformValues( argument: Expression, function: Expression) @@ -605,7 +605,7 @@ case class TransformValues( > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); {1:"ax",2:"by"} """, - since = "2.4.0") + since = "3.0.0") case class MapZipWith(left: Expression, right: Expression, function: Expression) extends HigherOrderFunction with CodegenFallback { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1cf2a30c0c8bd..efa8f8526387f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3731,7 +3731,7 @@ object functions { /** * Returns an unordered array of all entries in the given map. * @group collection_funcs - * @since 2.4.0 + * @since 3.0.0 */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } From 35c55163555f3671edd02ed0543785af82de07ca Mon Sep 17 00:00:00 2001 From: Julien Date: Mon, 19 Nov 2018 22:24:53 +0800 Subject: [PATCH 074/145] [SPARK-26024][SQL] Update documentation for repartitionByRange Following [SPARK-26024](https://issues.apache.org/jira/browse/SPARK-26024), I noticed the number of elements in each partition after repartitioning using `df.repartitionByRange` can vary for the same setup: ```scala // Shuffle numbers from 0 to 1000, and make a DataFrame val df = Random.shuffle(0.to(1000)).toDF("val") // Repartition it using 3 partitions // Sum up number of elements in each partition, and collect it. // And do it several times for (i <- 0 to 9) { var counts = df.repartitionByRange(3, col("val")) .mapPartitions{part => Iterator(part.size)} .collect() println(counts.toList) } // -> the number of elements in each partition varies ``` This is expected as for performance reasons this method uses sampling to estimate the ranges (with default size of 100). Hence, the output may not be consistent, since sampling can return different values. But documentation was not mentioning it at all, leading to misunderstanding. ## What changes were proposed in this pull request? Update the documentation (Spark & PySpark) to mention the impact of `spark.sql.execution.rangeExchange.sampleSizePerPartition` on the resulting partitioned DataFrame. Closes #23025 from JulienPeloton/SPARK-26024. 
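As a follow-up to the snippet above (illustration only, not part of this patch): if more stable range boundaries are needed, the sample size that the new documentation points to can be raised at runtime. The value used here is arbitrary, not a recommendation from this patch:

```scala
import scala.util.Random
import org.apache.spark.sql.functions.col

// Assumes spark.implicits._ is in scope, as in the snippet above (e.g. spark-shell).
val df = Random.shuffle(0.to(1000)).toDF("val")

// Default is 100 sampled rows per partition; a larger sample makes the estimated
// range boundaries (and hence the per-partition counts) less variable across runs,
// though still not guaranteed to be identical.
spark.conf.set("spark.sql.execution.rangeExchange.sampleSizePerPartition", 10000)

val counts = df.repartitionByRange(3, col("val"))
  .mapPartitions(part => Iterator(part.size))
  .collect()
println(counts.toList)
```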
Authored-by: Julien Signed-off-by: Wenchen Fan --- R/pkg/R/DataFrame.R | 8 ++++++++ python/pyspark/sql/dataframe.py | 5 +++++ .../src/main/scala/org/apache/spark/sql/Dataset.scala | 11 +++++++++++ 3 files changed, 24 insertions(+) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c99ad76f7643c..52e76570139e2 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -767,6 +767,14 @@ setMethod("repartition", #' using \code{spark.sql.shuffle.partitions} as number of partitions.} #'} #' +#' At least one partition-by expression must be specified. +#' When no explicit sort order is specified, "ascending nulls first" is assumed. +#' +#' Note that due to performance reasons this method uses sampling to estimate the ranges. +#' Hence, the output may not be consistent, since sampling can return different values. +#' The sample size can be controlled by the config +#' \code{spark.sql.execution.rangeExchange.sampleSizePerPartition}. +#' #' @param x a SparkDataFrame. #' @param numPartitions the number of partitions to use. #' @param col the column by which the range partitioning will be performed. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5748f6c6bd5eb..c4f4d81999544 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -732,6 +732,11 @@ def repartitionByRange(self, numPartitions, *cols): At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed. + Note that due to performance reasons this method uses sampling to estimate the ranges. + Hence, the output may not be consistent, since sampling can return different values. + The sample size can be controlled by the config + `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() 2 >>> df.show() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f5caaf3f7fc87..0e77ec0406257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2787,6 +2787,12 @@ class Dataset[T] private[sql]( * When no explicit sort order is specified, "ascending nulls first" is assumed. * Note, the rows are not sorted in each partition of the resulting Dataset. * + * + * Note that due to performance reasons this method uses sampling to estimate the ranges. + * Hence, the output may not be consistent, since sampling can return different values. + * The sample size can be controlled by the config + * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * * @group typedrel * @since 2.3.0 */ @@ -2811,6 +2817,11 @@ class Dataset[T] private[sql]( * When no explicit sort order is specified, "ascending nulls first" is assumed. * Note, the rows are not sorted in each partition of the resulting Dataset. * + * Note that due to performance reasons this method uses sampling to estimate the ranges. + * Hence, the output may not be consistent, since sampling can return different values. + * The sample size can be controlled by the config + * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * * @group typedrel * @since 2.3.0 */ From 219b037f05636a3a7c8116987c319773f4145b63 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 19 Nov 2018 22:42:24 +0800 Subject: [PATCH 075/145] [SPARK-26071][SQL] disallow map as map key ## What changes were proposed in this pull request? 
Due to an implementation limitation, Spark currently can't compare or do equality checks between map types. As a result, map values can't appear in EQUAL or comparison expressions, can't be grouping keys, etc. More importantly, a map lookup needs an equality check on the map key, so a map can't be used as a map key when looking up values in a map. Thus it's not useful to have map as a map key type. This PR proposes to stop users from creating maps that use a map type as the key. The list of expressions that are updated: `CreateMap`, `MapFromArrays`, `MapFromEntries`, `MapConcat`, `TransformKeys`. I manually checked all the places that create `MapType`, and came up with this list. Note that maps with map type keys still exist, e.g. when reading from Parquet files or converting from Scala/Java maps. This PR does not completely forbid map as a map key; it only avoids creating such maps in Spark itself. Motivation: when I was trying to fix the duplicate key problem, I found it's impossible to do with a map type map key. I think it's reasonable to avoid map type map keys for built-in functions. ## How was this patch tested? Updated tests. Closes #23045 from cloud-fan/map-key. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 6 +- .../expressions/collectionOperations.scala | 12 +- .../expressions/complexTypeCreator.scala | 14 ++- .../expressions/higherOrderFunctions.scala | 4 + .../spark/sql/catalyst/util/TypeUtils.scala | 10 +- .../CollectionExpressionsSuite.scala | 113 ++++++++++-------- .../expressions/ComplexTypeSuite.scala | 83 +++++++------ .../expressions/ExpressionEvalHelper.scala | 21 +++- .../HigherOrderFunctionsSuite.scala | 41 ++++--- .../inputs/typeCoercion/native/mapconcat.sql | 9 +- .../typeCoercion/native/mapconcat.sql.out | 19 ++- 11 files changed, 203 insertions(+), 129 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 50458e96f7c3f..07079d93f25b6 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,6 +17,8 @@ displayTitle: Spark SQL Upgrading Guide - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. + - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. @@ -117,7 +119,7 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. - - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files.
For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. + - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`. @@ -303,7 +305,7 @@ displayTitle: Spark SQL Upgrading Guide ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. - + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). - Since Spark 2.2, view definitions are stored in a different way from prior versions. This may cause Spark unable to read views created by prior versions. In such cases, you need to recreate the views using `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS` with newer Spark versions. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3c260954a72a2..43116743e9952 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -521,13 +521,18 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def checkInputDataTypes(): TypeCheckResult = { - var funcName = s"function $prettyName" + val funcName = s"function $prettyName" if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( s"input to $funcName should all be of type map, but it's " + children.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + val sameTypeCheck = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + if (sameTypeCheck.isFailure) { + sameTypeCheck + } else { + TypeUtils.checkForMapKeyType(dataType.keyType) + } } } @@ -740,7 +745,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { - case Some(_) => TypeCheckResult.TypeCheckSuccess + case Some((mapType, _, _)) => + TypeUtils.checkForMapKeyType(mapType.keyType) case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + s"${child.dataType.catalogString} type. 
$prettyName accepts only arrays of pair structs.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0361372b6b732..6b77996789f1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -161,11 +161,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { "The given values of function map should all be the same type, but they are " + values.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForMapKeyType(dataType.keyType) } } - override def dataType: DataType = { + override def dataType: MapType = { MapType( keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) .getOrElse(StringType), @@ -224,6 +224,16 @@ case class MapFromArrays(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else { + val keyType = left.dataType.asInstanceOf[ArrayType].elementType + TypeUtils.checkForMapKeyType(keyType) + } + } + override def dataType: DataType = { MapType( keyType = left.dataType.asInstanceOf[ArrayType].elementType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 0b698f9290711..8b31021866220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -514,6 +514,10 @@ case class TransformKeys( override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + override def checkInputDataTypes(): TypeCheckResult = { + TypeUtils.checkForMapKeyType(function.dataType) + } + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 76218b459ef0d..2a71fdb7592bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -46,12 +46,20 @@ object TypeUtils { if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - return TypeCheckResult.TypeCheckFailure( + TypeCheckResult.TypeCheckFailure( s"input to $caller should all be the same type, but it's " + types.map(_.catalogString).mkString("[", ", ", "]")) } } + def checkForMapKeyType(keyType: DataType): TypeCheckResult = { + if (keyType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure("The key of map cannot be/contain map.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2e0adbb465008..1415b7da6fca1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -25,6 +25,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -108,32 +109,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Map Concat") { - val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + val m0 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, valueContainsNull = false)) - val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + val m1 = Literal.create(create_map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, valueContainsNull = false)) - val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) - val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) - val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) - val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) - val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) - val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + val m2 = Literal.create(create_map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(create_map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(create_map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(create_map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(create_map(List(1, 2) -> 1, List(3, 4) -> 2), MapType(ArrayType(IntegerType), IntegerType)) - val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + val m8 = Literal.create(create_map(List(5, 6) -> 3, List(1, 2) -> 4), MapType(ArrayType(IntegerType), IntegerType)) - val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), - MapType(MapType(IntegerType, IntegerType), IntegerType)) - val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), - MapType(MapType(IntegerType, IntegerType), IntegerType)) - val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + val m9 = Literal.create(create_map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, valueContainsNull = false)) - val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + val m10 = Literal.create(create_map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, valueContainsNull = false)) - val m13 = Literal.create(Map(1 -> 2, 3 -> 4), + val m11 = Literal.create(create_map(1 -> 2, 3 -> 4), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val m14 = Literal.create(Map(5 -> 
6), + val m12 = Literal.create(create_map(5 -> 6), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val m15 = Literal.create(Map(7 -> null), + val m13 = Literal.create(create_map(7 -> null), MapType(IntegerType, IntegerType, valueContainsNull = true)) val mNull = Literal.create(null, MapType(StringType, StringType)) @@ -147,7 +144,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), - Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) // 3 maps checkEvaluation(MapConcat(Seq(m0, m1, m2)), @@ -174,7 +171,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) // keys that are primitive - checkEvaluation(MapConcat(Seq(m11, m12)), + checkEvaluation(MapConcat(Seq(m9, m10)), ( Array(1, 2, 3, 4), // keys Array("1", "2", "3", "4") // values @@ -189,20 +186,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) ) - // keys that are maps, with overlap - checkEvaluation(MapConcat(Seq(m9, m10)), - ( - Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), - Map(1 -> 2, 3 -> 4)), // keys - Array(1, 2, 3, 4) // values - ) - ) - // both keys and value are primitive and valueContainsNull = false - checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6)) + checkEvaluation(MapConcat(Seq(m11, m12)), create_map(1 -> 2, 3 -> 4, 5 -> 6)) // both keys and value are primitive and valueContainsNull = true - checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null)) + checkEvaluation(MapConcat(Seq(m11, m13)), create_map(1 -> 2, 3 -> 4, 7 -> null)) // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) @@ -211,7 +199,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq(mNull)), null) // single map - checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + checkEvaluation(MapConcat(Seq(m0)), create_map("a" -> "1", "b" -> "2")) // no map checkEvaluation(MapConcat(Seq.empty), Map.empty) @@ -245,12 +233,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(MapConcat(Seq(m1, mNull)).nullable) val mapConcat = MapConcat(Seq( - Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + Literal.create(create_map(Seq(1, 2) -> Seq("a", "b")), MapType( ArrayType(IntegerType, containsNull = false), ArrayType(StringType, containsNull = false), valueContainsNull = false)), - Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + Literal.create(create_map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), MapType( ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), @@ -264,6 +252,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(1, 2) -> Seq("a", "b"), Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null)) + + // map key can't be map + val mapOfMap = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val mapOfMap2 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val map = MapConcat(Seq(mapOfMap, mapOfMap2)) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case 
TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("MapFromEntries") { @@ -274,20 +274,20 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper StructField("b", valueType))), true) } - def r(values: Any*): InternalRow = create_row(values: _*) + def row(values: Any*): InternalRow = create_row(values: _*) // Primitive-type keys and values val aiType = arrayType(IntegerType, IntegerType) - val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) - val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai0 = Literal.create(Seq(row(1, 10), row(2, 20), row(3, 20)), aiType) + val ai1 = Literal.create(Seq(row(1, null), row(2, 20), row(3, null)), aiType) val ai2 = Literal.create(Seq.empty, aiType) val ai3 = Literal.create(null, aiType) - val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) - val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) - val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + val ai4 = Literal.create(Seq(row(1, 10), row(1, 20)), aiType) + val ai5 = Literal.create(Seq(row(1, 10), row(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, row(2, 20), null), aiType) - checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) - checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai0), create_map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null)) checkEvaluation(MapFromEntries(ai2), Map.empty) checkEvaluation(MapFromEntries(ai3), null) checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) @@ -298,23 +298,36 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // Non-primitive-type keys and values val asType = arrayType(StringType, StringType) - val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) - val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as0 = Literal.create(Seq(row("a", "aa"), row("b", "bb"), row("c", "bb")), asType) + val as1 = Literal.create(Seq(row("a", null), row("b", "bb"), row("c", null)), asType) val as2 = Literal.create(Seq.empty, asType) val as3 = Literal.create(null, asType) - val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) - val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) - val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + val as4 = Literal.create(Seq(row("a", "aa"), row("a", "bb")), asType) + val as5 = Literal.create(Seq(row("a", "aa"), row(null, "bb")), asType) + val as6 = Literal.create(Seq(null, row("b", "bb"), null), asType) - checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) - checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as0), create_map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null)) checkEvaluation(MapFromEntries(as2), Map.empty) checkEvaluation(MapFromEntries(as3), null) checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkEvaluation(MapFromEntries(as6), null) + + // Map key can't be null checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), "The first field from a struct (key) can't be null.") - checkEvaluation(MapFromEntries(as6), null) + + // map key can't be map + val structOfMap = row(create_map(1 -> 1), 1) + val map = 
MapFromEntries(Literal.create( + Seq(structOfMap), + arrayType(keyType = MapType(IntegerType, IntegerType), valueType = IntegerType))) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("Sort Array") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 77aaf55480ec2..d95f42e04e37c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ @@ -158,40 +158,32 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { keys.zip(values).flatMap { case (k, v) => Seq(k, v) } } - def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { - // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. - scala.collection.immutable.ListMap(keys.zip(values): _*) - } - val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateMap(Nil), Map.empty) checkEvaluation( CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))), - createMap(intSeq, longSeq)) + create_map(intSeq, longSeq)) checkEvaluation( CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))), - createMap(strSeq, longSeq)) + create_map(strSeq, longSeq)) checkEvaluation( CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))), - createMap(longSeq, strSeq)) + create_map(longSeq, strSeq)) val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType) checkEvaluation( CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)), - createMap(intSeq, strWithNull.map(_.value))) - intercept[RuntimeException] { - checkEvaluationWithoutCodegen( - CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), - null, null) - } - intercept[RuntimeException] { - checkEvaluationWithUnsafeProjection( - CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), - null, null) - } + create_map(intSeq, strWithNull.map(_.value))) + // Map key can't be null + checkExceptionInExpression[RuntimeException]( + CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), + "Cannot use null as map key") + + // ArrayType map key and value val map = CreateMap(Seq( Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), Literal.create(strSeq, ArrayType(StringType, containsNull = false)), @@ -202,15 +194,21 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), valueContainsNull = false)) - checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) + checkEvaluation(map, create_map(intSeq -> strSeq, (intSeq :+ null) -> (strSeq :+ null))) + 
+ // map key can't be map + val map2 = CreateMap(Seq( + Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)), + Literal(1) + )) + map2.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("MapFromArrays") { - def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { - // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. - scala.collection.immutable.ListMap(keys.zip(values): _*) - } - val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) @@ -228,24 +226,33 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val nullArray = Literal.create(null, ArrayType(StringType, false)) - checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) - checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) - checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + checkEvaluation(MapFromArrays(intArray, longArray), create_map(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), create_map(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), create_map(integerSeq, strSeq)) checkEvaluation( - MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + MapFromArrays(strArray, intWithNullArray), create_map(strSeq, intWithNullSeq)) checkEvaluation( - MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) checkEvaluation( - MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) checkEvaluation(MapFromArrays(nullArray, nullArray), null) - intercept[RuntimeException] { - checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) - } - intercept[RuntimeException] { - checkEvaluation( - MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + // Map key can't be null + checkExceptionInExpression[RuntimeException]( + MapFromArrays(intWithNullArray, strArray), + "Cannot use null as map key") + + // map key can't be map + val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b")) + val map = MapFromArrays( + Literal.create(arrayOfMap, ArrayType(MapType(IntegerType, StringType))), + Literal.create(Seq(1), ArrayType(IntegerType)) + ) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index da18475276a13..eb33325d0b31a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.PlanTestBase import 
org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -48,6 +48,25 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } + // Currently MapData just stores the key and value arrays. Its equality is not well implemented, + // as the order of the map entries should not matter for equality. This method creates MapData + // with the entries ordering preserved, so that we can deterministically test expressions with + // map input/output. + protected def create_map(entries: (_, _)*): ArrayBasedMapData = { + create_map(entries.map(_._1), entries.map(_._2)) + } + + protected def create_map(keys: Seq[_], values: Seq[_]): ArrayBasedMapData = { + assert(keys.length == values.length) + val keyArray = CatalystTypeConverters + .convertToCatalyst(keys) + .asInstanceOf[ArrayData] + val valueArray = CatalystTypeConverters + .convertToCatalyst(values) + .asInstanceOf[ArrayData] + new ArrayBasedMapData(keyArray, valueArray) + } + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index e13f4d98295be..66bf18af95799 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types._ @@ -310,13 +311,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("TransformKeys") { val ai0 = Literal.create( - Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + create_map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), MapType(IntegerType, IntegerType, valueContainsNull = false)) val ai1 = Literal.create( Map.empty[Int, Int], MapType(IntegerType, IntegerType, valueContainsNull = true)) val ai2 = Literal.create( - Map(1 -> 1, 2 -> null, 3 -> 3), + create_map(1 -> 1, 2 -> null, 3 -> 3), MapType(IntegerType, IntegerType, valueContainsNull = true)) val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) @@ -324,26 +325,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusValue: (Expression, Expression) => Expression = (k, v) => k + v val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 - checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) - checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) + checkEvaluation(transformKeys(ai0, plusOne), create_map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), create_map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) checkEvaluation( - transformKeys(transformKeys(ai0, plusOne), 
plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + transformKeys(transformKeys(ai0, plusOne), plusValue), + create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) checkEvaluation(transformKeys(ai0, modKey), ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation( transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) - checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation(transformKeys(ai2, plusOne), create_map(2 -> 1, 3 -> null, 4 -> 3)) checkEvaluation( - transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + transformKeys(transformKeys(ai2, plusOne), plusOne), create_map(3 -> 1, 4 -> null, 5 -> 3)) checkEvaluation(transformKeys(ai3, plusOne), null) val as0 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + create_map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType, valueContainsNull = false)) val as1 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + create_map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType, valueContainsNull = true)) val as2 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) @@ -355,26 +357,35 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (k, v) => Length(k) + 1 checkEvaluation( - transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + transformKeys(as0, concatValue), create_map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) checkEvaluation( transformKeys(transformKeys(as0, concatValue), concatValue), - Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + create_map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) checkEvaluation( transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), Map.empty[Int, String]) checkEvaluation(transformKeys(as0, convertKeyToKeyLength), - Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + create_map(2 -> "xy", 3 -> "yz", 4 -> "zx")) checkEvaluation(transformKeys(as1, convertKeyToKeyLength), - Map(2 -> "xy", 3 -> "yz", 4 -> null)) + create_map(2 -> "xy", 3 -> "yz", 4 -> null)) checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) val ax0 = Literal.create( - Map(1 -> "x", 2 -> "y", 3 -> "z"), + create_map(1 -> "x", 2 -> "y", 3 -> "z"), MapType(IntegerType, StringType, valueContainsNull = false)) - checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) + checkEvaluation(transformKeys(ax0, plusOne), create_map(2 -> "x", 3 -> "y", 4 -> "z")) + + // map key can't be map + val makeMap: (Expression, Expression) => Expression = (k, v) => CreateMap(Seq(k, v)) + val map = transformKeys(ai0, makeMap) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("TransformValues") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql index 69da67fc66fc0..60895020fcc83 100644 --- 
a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -13,7 +13,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( map('a', 'b'), map('c', 'd'), map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), - map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), map('a', 1), map('c', 2), map(1, 'a'), map(2, 'c') ) AS various_maps ( @@ -31,7 +30,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( string_map1, string_map2, array_map1, array_map2, struct_map1, struct_map2, - map_map1, map_map2, string_int_map1, string_int_map2, int_string_map1, int_string_map2 ); @@ -51,7 +49,6 @@ SELECT map_concat(string_map1, string_map2) string_map, map_concat(array_map1, array_map2) array_map, map_concat(struct_map1, struct_map2) struct_map, - map_concat(map_map1, map_map2) map_map, map_concat(string_int_map1, string_int_map2) string_int_map, map_concat(int_string_map1, int_string_map2) int_string_map FROM various_maps; @@ -71,7 +68,7 @@ FROM various_maps; -- Concatenate map of incompatible types 1 SELECT - map_concat(tinyint_map1, map_map2) tm_map + map_concat(tinyint_map1, array_map1) tm_map FROM various_maps; -- Concatenate map of incompatible types 2 @@ -86,10 +83,10 @@ FROM various_maps; -- Concatenate map of incompatible types 4 SELECT - map_concat(map_map1, array_map2) ma_map + map_concat(struct_map1, array_map2) ma_map FROM various_maps; -- Concatenate map of incompatible types 5 SELECT - map_concat(map_map1, struct_map2) ms_map + map_concat(int_map1, array_map2) ms_map FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out index efc88e47209a6..79e00860e4c05 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -18,7 +18,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( map('a', 'b'), map('c', 'd'), map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), - map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), map('a', 1), map('c', 2), map(1, 'a'), map(2, 'c') ) AS various_maps ( @@ -36,7 +35,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( string_map1, string_map2, array_map1, array_map2, struct_map1, struct_map2, - map_map1, map_map2, string_int_map1, string_int_map2, int_string_map1, int_string_map2 ) @@ -61,14 +59,13 @@ SELECT map_concat(string_map1, string_map2) string_map, map_concat(array_map1, array_map2) array_map, map_concat(struct_map1, struct_map2) struct_map, - map_concat(map_map1, map_map2) map_map, map_concat(string_int_map1, string_int_map2) string_int_map, map_concat(int_string_map1, int_string_map2) int_string_map FROM various_maps -- !query 1 schema -struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> 
+struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,string_int_map:map,int_string_map:map> -- !query 1 output -{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {"a":1,"c":2} {1:"a",2:"c"} -- !query 2 @@ -91,13 +88,13 @@ struct,si_map:map,ib_map:map -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`array_map1`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 -- !query 4 @@ -124,21 +121,21 @@ cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' -- !query 6 SELECT - map_concat(map_map1, array_map2) ma_map + map_concat(struct_map1, array_map2) ma_map FROM various_maps -- !query 6 schema struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`struct_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,struct>, map,array>]; line 2 pos 4 -- !query 7 SELECT - map_concat(map_map1, struct_map2) ms_map + map_concat(int_map1, array_map2) ms_map FROM various_maps -- !query 7 schema struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 From 32365f8177f913533d348f7079605a282f1014ef Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 09:16:42 -0600 Subject: [PATCH 076/145] [SPARK-26090][CORE][SQL][ML] Resolve most miscellaneous deprecation and build warnings for Spark 3 ## What changes were proposed in this pull request? The build has a lot of deprecation warnings. 
Some are new in Scala 2.12 and Java 11. We've fixed some, but I wanted to take a pass at fixing lots of easy miscellaneous ones here. They're too numerous and small to list here; see the pull request. Some highlights: - `BeanInfo` is deprecated in 2.12, and BeanInfo classes are pretty ancient in Java. Instead, case classes can explicitly declare getters - Eta expansion of zero-arg methods; foo() becomes () => foo() in many cases - Floating-point Range is inexact and deprecated, like 0.0 to 100.0 by 1.0 - finalize() is finally deprecated (just needs to be suppressed) - StageInfo.attempId was deprecated and easiest to remove here I'm not now going to touch some chunks of deprecation warnings: - Parquet deprecations - Hive deprecations (particularly serde2 classes) - Deprecations in generated code (mostly Thriftserver CLI) - ProcessingTime deprecations (we may need to revive this class as internal) - many MLlib deprecations because they concern methods that may be removed anyway - a few Kinesis deprecations I couldn't figure out - Mesos get/setRole, which I don't know well - Kafka/ZK deprecations (e.g. poll()) - Kinesis - a few other ones that will probably resolve by deleting a deprecated method ## How was this patch tested? Existing tests, including manual testing with the 2.11 build and Java 11. Closes #23065 from srowen/SPARK-26090. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../spark/util/kvstore/LevelDBIterator.java | 1 + common/unsafe/pom.xml | 5 ++++ .../types/UTF8StringPropertyCheckSuite.scala | 4 +-- .../spark/io/NioBufferedFileInputStream.java | 3 +- .../scala/org/apache/spark/SparkContext.scala | 4 ++- .../org/apache/spark/api/r/RBackend.scala | 8 ++--- .../HadoopDelegationTokenManager.scala | 4 ++- .../org/apache/spark/executor/Executor.scala | 7 +++-- .../apache/spark/scheduler/StageInfo.scala | 4 ++- .../rdd/ParallelCollectionSplitSuite.scala | 4 +-- .../serializer/KryoSerializerSuite.scala | 10 +++---- .../spark/status/AppStatusListenerSuite.scala | 17 +++++------ .../ExternalAppendOnlyMapSuite.scala | 1 + .../org/apache/spark/sql/avro/AvroSuite.scala | 20 ++++++------- .../sql/kafka010/KafkaContinuousTest.scala | 4 ++- .../streaming/kafka010/ConsumerStrategy.scala | 6 ++-- .../spark/ml/feature/LabeledPoint.scala | 8 +++-- .../ml/feature/QuantileDiscretizer.scala | 10 ++----- .../spark/mllib/regression/LabeledPoint.scala | 8 +++-- .../spark/mllib/stat/test/StreamingTest.scala | 5 ++-- .../apache/spark/ml/feature/DCTSuite.scala | 8 ++--- .../apache/spark/ml/feature/NGramSuite.scala | 9 +++--- .../ml/feature/QuantileDiscretizerSuite.scala | 12 ++++---- .../spark/ml/feature/TokenizerSuite.scala | 8 ++--- .../spark/ml/feature/VectorIndexerSuite.scala | 9 +++--- .../spark/ml/recommendation/ALSSuite.scala | 2 +- pom.xml | 5 ++++ project/MimaExcludes.scala | 5 ++++ .../k8s/submit/KubernetesDriverBuilder.scala | 2 +- .../k8s/KubernetesExecutorBuilder.scala | 2 +- .../spark/deploy/k8s/submit/ClientSuite.scala | 3 +- .../k8s/ExecutorPodsAllocatorSuite.scala | 2 +- .../deploy/yarn/YarnAllocatorSuite.scala | 3 +- .../util/HyperLogLogPlusPlusHelper.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 12 ++++---- .../sql/streaming/StreamingQueryManager.scala | 2 +- .../sql/JavaBeanDeserializationSuite.java | 17 ++++++++++- .../sources/v2/JavaRangeInputPartition.java | 30 +++++++++++++++++++ .../sql/sources/v2/JavaSimpleReadSupport.java | 9 ------ .../spark/sql/UserDefinedTypeSuite.scala | 10 +++---- .../compression/IntegralDeltaSuite.scala | 3 +- 
.../ProcessingTimeExecutorSuite.scala | 5 +--- .../sources/TextSocketStreamSuite.scala | 5 ++-- .../sql/util/DataFrameCallbackSuite.scala | 2 +- .../HiveCliSessionStateSuite.scala | 2 +- 45 files changed, 177 insertions(+), 125 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index f62e85d435318..e3efc92c4a54a 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -196,6 +196,7 @@ public synchronized void close() throws IOException { * when Scala wrappers are used, this makes sure that, hopefully, the JNI resources held by * the iterator will eventually be released. */ + @SuppressWarnings("deprecation") @Override protected void finalize() throws Throwable { db.closeIterator(this); diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 7e4b08217f1b0..93a4f67fd23f2 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -89,6 +89,11 @@ commons-lang3 test + + org.apache.commons + commons-text + test + target/scala-${scala.binary.version}/classes diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 9656951810daf..fdb81a06d41c9 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types -import org.apache.commons.lang3.StringUtils +import org.apache.commons.text.similarity.LevenshteinDistance import org.scalacheck.{Arbitrary, Gen} import org.scalatest.prop.GeneratorDrivenPropertyChecks // scalastyle:off @@ -232,7 +232,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty test("levenshteinDistance") { forAll { (one: String, another: String) => assert(toUTF8(one).levenshteinDistance(toUTF8(another)) === - StringUtils.getLevenshteinDistance(one, another)) + LevenshteinDistance.getDefaultInstance.apply(one, another)) } } diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..92bf0ecc1b5cb 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -27,7 +27,7 @@ * to read a file to avoid extra copy of data between Java and * native memory which happens when using {@link java.io.BufferedInputStream}. * Unfortunately, this is not something already available in JDK, - * {@link sun.nio.ch.ChannelInputStream} supports reading a file using nio, + * {@code sun.nio.ch.ChannelInputStream} supports reading a file using nio, * but does not support buffering. 
*/ public final class NioBufferedFileInputStream extends InputStream { @@ -130,6 +130,7 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + @SuppressWarnings("deprecation") @Override protected void finalize() throws IOException { close(); diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cb91717dfa121..845a3d5f6d6f9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -502,7 +502,9 @@ class SparkContext(config: SparkConf) extends Logging { _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // create and start the heartbeater for collecting memory metrics - _heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, "driver-heartbeater", + _heartbeater = new Heartbeater(env.memoryManager, + () => SparkContext.this.reportHeartBeat(), + "driver-heartbeater", conf.get(EXECUTOR_HEARTBEAT_INTERVAL)) _heartbeater.start() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 7ce2581555014..50c8fdf5316d6 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.r -import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.io.{DataOutputStream, File, FileOutputStream, IOException} import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit @@ -32,8 +32,6 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. 
@@ -99,7 +97,7 @@ private[spark] class RBackend { if (bootstrap != null && bootstrap.config().group() != null) { bootstrap.config().group().shutdownGracefully() } - if (bootstrap != null && bootstrap.childGroup() != null) { + if (bootstrap != null && bootstrap.config().childGroup() != null) { bootstrap.config().childGroup().shutdownGracefully() } bootstrap = null @@ -147,7 +145,7 @@ private[spark] object RBackend extends Logging { new Thread("wait for socket to close") { setDaemon(true) override def run(): Unit = { - // any un-catched exception will also shutdown JVM + // any uncaught exception will also shutdown JVM val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 10cd8742f2b49..1169b2878e993 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -270,7 +270,9 @@ private[spark] class HadoopDelegationTokenManager( } private def loadProviders(): Map[String, HadoopDelegationTokenProvider] = { - val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystemsToAccess)) ++ + val providers = Seq( + new HadoopFSDelegationTokenProvider( + () => HadoopDelegationTokenManager.this.fileSystemsToAccess())) ++ safeCreateProvider(new HiveDelegationTokenProvider) ++ safeCreateProvider(new HBaseDelegationTokenProvider) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 61deb543d8747..a30a501e5d4a1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -190,8 +190,11 @@ private[spark] class Executor( private val HEARTBEAT_INTERVAL_MS = conf.get(EXECUTOR_HEARTBEAT_INTERVAL) // Executor for the heartbeat task. 
- private val heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, - "executor-heartbeater", HEARTBEAT_INTERVAL_MS) + private val heartbeater = new Heartbeater( + env.memoryManager, + () => Executor.this.reportHeartBeat(), + "executor-heartbeater", + HEARTBEAT_INTERVAL_MS) // must be initialized before running startDriverHeartbeat() private val heartbeatReceiverRef = diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 903e25b7986f2..33a68f24bd53a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, + private val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + // This would just be the second constructor arg, except we need to maintain this method + // with parentheses for compatibility def attemptNumber(): Int = attemptId private[spark] def getStatusString: String = { diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 31ce9483cf20a..424d9f825c465 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -215,7 +215,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("exclusive ranges of doubles") { - val data = 1.0 until 100.0 by 1.0 + val data = Range.BigDecimal(1, 100, 1) val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).sum === 99) @@ -223,7 +223,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("inclusive ranges of doubles") { - val data = 1.0 to 100.0 by 1.0 + val data = Range.BigDecimal.inclusive(1, 100, 1) val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).sum === 100) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 84af73b08d3e7..e413fe3b774d0 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -202,7 +202,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time - assert(ser.serialize(t).limit() < 100) + assert(ser.serialize(t).limit() < 200) } check(1 to 1000000) check(1 to 1000000 by 2) @@ -212,10 +212,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(1L to 1000000L by 2L) check(1L until 1000000L) check(1L until 1000000L by 2L) - check(1.0 to 1000000.0 by 1.0) - check(1.0 to 1000000.0 by 2.0) - check(1.0 until 1000000.0 by 1.0) - check(1.0 until 1000000.0 by 2.0) + check(Range.BigDecimal.inclusive(1, 1000000, 1)) + check(Range.BigDecimal.inclusive(1, 1000000, 2)) + check(Range.BigDecimal(1, 1000000, 1)) + check(Range.BigDecimal(1, 1000000, 2)) } 
test("asJavaIterable") { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index bfd73069fbff8..5f757b757ac61 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.status import java.io.File -import java.lang.{Integer => JInteger, Long => JLong} -import java.util.{Arrays, Date, Properties} +import java.util.{Date, Properties} import scala.collection.JavaConverters._ import scala.collection.immutable.Map @@ -1171,12 +1170,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Stop task 2 before task 1 time += 1 tasks(1).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(1), null)) time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(0), null)) // Start task 3 and task 2 should be evicted. listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2))) @@ -1241,8 +1240,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Task 1 Finished time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(0), null)) // Stage 1 Completed stage1.failureReason = Some("Failed") @@ -1256,7 +1255,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 tasks(1).markFinished(TaskState.FINISHED, time) listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", + SparkListenerTaskEnd(stage1.stageId, stage1.attemptNumber, "taskType", TaskKilled(reason = "Killed"), tasks(1), null)) // Ensure killed task metrics are updated diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index cd25265784136..35fba1a3b73c6 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ +import scala.language.postfixOps import scala.ref.WeakReference import org.scalatest.Matchers diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 8d6cca8e48c3d..207c54ce75f4c 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -138,7 +138,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test NULL avro type") { withTempPath { dir => val 
fields = - Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -161,7 +161,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -189,7 +189,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -221,7 +221,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -247,7 +247,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("Union of a single type") { withTempPath { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -274,10 +274,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null), - new Field("field2", complexUnionType, "doc", null), - new Field("field3", complexUnionType, "doc", null), - new Field("field4", complexUnionType, "doc", null) + new Field("field1", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field2", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field3", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field4", complexUnionType, "doc", null.asInstanceOf[AnyVal]) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -941,7 +941,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable) val avroMapType = resolveNullable(Schema.createMap(avroType), nullable) val name = "foo" - val avroField = new Field(name, avroType, "", null) + val avroField = new Field(name, avroType, "", null.asInstanceOf[AnyVal]) val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava) val avroRecordType = resolveNullable(recordSchema, 
nullable) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa6bdc20bd4f9..aa21f1271b817 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -56,7 +56,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { } // Continuous processing tasks end asynchronously, so test that they actually end. - private val tasksEndedListener = new SparkListener() { + private class TasksEndedListener extends SparkListener { val activeTaskIdCount = new AtomicInteger(0) override def onTaskStart(start: SparkListenerTaskStart): Unit = { @@ -68,6 +68,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { } } + private val tasksEndedListener = new TasksEndedListener() + override def beforeEach(): Unit = { super.beforeEach() spark.sparkContext.addSparkListener(tasksEndedListener) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index cf283a5c3e11e..07960d14b0bfc 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -228,7 +228,7 @@ object ConsumerStrategies { new Subscribe[K, V]( new ju.ArrayList(topics.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** @@ -307,7 +307,7 @@ object ConsumerStrategies { new SubscribePattern[K, V]( pattern, new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** @@ -391,7 +391,7 @@ object ConsumerStrategies { new Assign[K, V]( new ju.ArrayList(topicPartitions.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala index c5d0ec1a8d350..412954f7b2d5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector @@ -30,8 +28,12 @@ import org.apache.spark.ml.linalg.Vector * @param features List of features for this data point. 
*/ @Since("2.0.0") -@BeanInfo case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features: Vector) { + + def getLabel: Double = label + + def getFeatures: Vector = features + override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 56e2c543d100a..5bfaa3b7f3f52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,10 +17,6 @@ package org.apache.spark.ml.feature -import org.json4s.JsonDSL._ -import org.json4s.JValue -import org.json4s.jackson.JsonMethods._ - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -209,7 +205,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui if (isSet(inputCols)) { val splitsArray = if (isSet(numBucketsArray)) { val probArrayPerCol = $(numBucketsArray).map { numOfBuckets => - (0.0 to 1.0 by 1.0 / numOfBuckets).toArray + (0 to numOfBuckets).map(_.toDouble / numOfBuckets).toArray } val probabilityArray = probArrayPerCol.flatten.sorted.distinct @@ -229,12 +225,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui } } else { dataset.stat.approxQuantile($(inputCols), - (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError)) } bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits)) } else { val splits = dataset.stat.approxQuantile($(inputCol), - (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError)) bucketizer.setSplits(getDistinctSplits(splits)) } copyValues(bucketizer.setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 4381d6ab20cc0..b320057b25276 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.regression -import scala.beans.BeanInfo - import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} @@ -32,10 +30,14 @@ import org.apache.spark.mllib.util.NumericParser * @param features List of features for this data point. 
*/ @Since("0.8.0") -@BeanInfo case class LabeledPoint @Since("1.0.0") ( @Since("0.8.0") label: Double, @Since("1.0.0") features: Vector) { + + def getLabel: Double = label + + def getFeatures: Vector = features + override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 80c6ef0ea1aa1..85ed11d6553d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.stat.test -import scala.beans.BeanInfo - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.streaming.api.java.JavaDStream @@ -32,10 +30,11 @@ import org.apache.spark.util.StatCounter * @param value numeric value of the observation. */ @Since("1.6.0") -@BeanInfo case class BinarySample @Since("1.6.0") ( @Since("1.6.0") isExperiment: Boolean, @Since("1.6.0") value: Double) { + def getIsExperiment: Boolean = isExperiment + def getValue: Double = value override def toString: String = { s"($isExperiment, $value)" } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 6734336aac39c..985e396000d05 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row -@BeanInfo -case class DCTTestData(vec: Vector, wantedVec: Vector) +case class DCTTestData(vec: Vector, wantedVec: Vector) { + def getVec: Vector = vec + def getWantedVec: Vector = wantedVec +} class DCTSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index 201a335e0d7be..1483d5df4d224 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} - -@BeanInfo -case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) { + def getInputTokens: Array[String] = inputTokens + def getWantedNGrams: Array[String] = wantedNGrams +} class NGramSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b009038bbd833..82af05039653e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -31,7 +31,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val datasetSize = 100000 val numBuckets = 5 - val df = sc.parallelize(1.0 to 
datasetSize by 1.0).map(Tuple1.apply).toDF("input") + val df = sc.parallelize(1 to datasetSize).map(_.toDouble).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") @@ -114,8 +114,8 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val spark = this.spark import spark.implicits._ - val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") - val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") + val trainDF = sc.parallelize((1 to 100).map(_.toDouble)).map(Tuple1.apply).toDF("input") + val testDF = sc.parallelize((-10 to 110).map(_.toDouble)).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") @@ -276,10 +276,10 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) val data2 = Array.range(1, 40, 2).map(_.toDouble) val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, - 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) + 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) val data3 = Array.range(1, 60, 3).map(_.toDouble) - val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, - 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) + val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, + 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) val data = (0 until 20).map { idx => (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index be59b0af2c78e..ba8e79f14de95 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -@BeanInfo -case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) { + def getRawText: String = rawText + def getWantedTokens: Array[String] = wantedTokens +} class TokenizerSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index fb5789f945dec..44b0f8f8ae7d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ @@ -26,7 +24,7 @@ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { @@ -339,6 +337,7 @@ class 
VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { } private[feature] object VectorIndexerSuite { - @BeanInfo - case class FeatureData(@BeanProperty features: Vector) + case class FeatureData(features: Vector) { + def getFeatures: Vector = features + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9a59c41740daf..2fc9754ecfe1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -601,7 +601,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val df = maybeDf.get._2 val expected = estimator.fit(df) - val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + val actuals = dfs.map(t => (t, estimator.fit(t._2))) actuals.foreach { case (_, actual) => check(expected, actual) } actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } diff --git a/pom.xml b/pom.xml index fcec295eee128..9130773cb5094 100644 --- a/pom.xml +++ b/pom.xml @@ -407,6 +407,11 @@ commons-lang3 ${commons-lang3.version} + + org.apache.commons + commons-text + 1.6 + commons-lang commons-lang diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a8d2b5d1d9cb6..e35e74aa33045 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), + // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index be4daec3b1bb9..167fb402cd402 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -55,7 +55,7 @@ private[spark] class KubernetesDriverBuilder( providePodTemplateConfigMapStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => PodTemplateConfigMapStep) = new PodTemplateConfigMapStep(_), - provideInitialPod: () => SparkPod = SparkPod.initialPod) { + provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 089f84dec277f..fc41a4770bce6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -53,7 +53,7 @@ private[spark] class KubernetesExecutorBuilder( KubernetesConf[KubernetesExecutorSpecificConf] => HadoopSparkUserExecutorFeatureStep) = new HadoopSparkUserExecutorFeatureStep(_), - provideInitialPod: () => SparkPod = SparkPod.initialPod) { + provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 81e3822389f30..08f28758ef485 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} -import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} +import io.fabric8.kubernetes.client.dsl.PodResource import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} import org.scalatest.BeforeAndAfter @@ -28,7 +28,6 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class ClientSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index b336774838bcb..2f984e5d89808 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -157,7 +157,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { override def matches(argument: scala.Any): Boolean = { - if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + if (!argument.isInstanceOf[KubernetesConf[_]]) { false } else { val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 
35299166d9814..c3070de3d17cf 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -116,8 +116,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter } def createContainer(host: String, resource: Resource = containerResource): Container = { - // When YARN 2.6+ is required, avoid deprecation by using version with long second arg - val containerId = ContainerId.newInstance(appAttemptId, containerNum) + val containerId = ContainerId.newContainerId(appAttemptId, containerNum) containerNum += 1 val nodeId = NodeId.newInstance(host, 1000) Container.newInstance(containerId, nodeId, "", resource, RM_REQUEST_PRIORITY, null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala index 9bacd3b925be3..ea619c6a7666c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala @@ -199,7 +199,7 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable { var shift = 0 while (idx < m && i < REGISTERS_PER_WORD) { val Midx = (word >>> shift) & REGISTER_WORD_MASK - zInverse += 1.0 / (1 << Midx) + zInverse += 1.0 / (1L << Midx) if (Midx == 0) { V += 1.0d } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 94778840d706b..117e96175e92a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -30,8 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -@BeanInfo -private[sql] case class GroupableData(@BeanProperty data: Int) +private[sql] case class GroupableData(data: Int) { + def getData: Int = data +} private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { @@ -50,8 +49,9 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { private[spark] override def asNullable: GroupableUDT = this } -@BeanInfo -private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) +private[sql] case class UngroupableData(data: Map[Int, Int]) { + def getData: Map[Int, Int] = data +} private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index d9ea8dc9d4ac9..d9fe1a992a093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -311,7 +311,7 @@ class StreamingQueryManager 
private[sql] (sparkSession: SparkSession) extends Lo outputMode: OutputMode, useTempCheckpointLocation: Boolean = false, recoverFromCheckpointLocation: Boolean = true, - trigger: Trigger = ProcessingTime(0), + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock()): StreamingQuery = { val query = createQuery( userSpecifiedName, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 7f975a647c241..8f35abeb579b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -143,11 +143,16 @@ public void setIntervals(List intervals) { this.intervals = intervals; } + @Override + public int hashCode() { + return id ^ Objects.hashCode(intervals); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof ArrayRecord)) return false; ArrayRecord other = (ArrayRecord) obj; - return (other.id == this.id) && other.intervals.equals(this.intervals); + return (other.id == this.id) && Objects.equals(other.intervals, this.intervals); } @Override @@ -184,6 +189,11 @@ public void setIntervals(Map intervals) { this.intervals = intervals; } + @Override + public int hashCode() { + return id ^ Objects.hashCode(intervals); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof MapRecord)) return false; @@ -225,6 +235,11 @@ public void setEndTime(long endTime) { this.endTime = endTime; } + @Override + public int hashCode() { + return Long.hashCode(startTime) ^ Long.hashCode(endTime); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof Interval)) return false; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java new file mode 100644 index 0000000000000..438f489a3eea7 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import org.apache.spark.sql.sources.v2.reader.InputPartition; + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java index 685f9b9747e85..ced51dde6997b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -88,12 +88,3 @@ public void close() throws IOException { } } -class JavaRangeInputPartition implements InputPartition { - int start; - int end; - - JavaRangeInputPartition(int start, int end) { - this.start = start; - this.end = end; - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index cc8b600efa46a..cf956316057eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} @@ -28,10 +26,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -@BeanInfo -private[sql] case class MyLabeledPoint( - @BeanProperty label: Double, - @BeanProperty features: UDT.MyDenseVector) +private[sql] case class MyLabeledPoint(label: Double, features: UDT.MyDenseVector) { + def getLabel: Double = label + def getFeatures: UDT.MyDenseVector = features +} // Wrapped in an object to check Scala compatibility. 
See SPARK-13929 object UDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 0d9f1fb0c02c9..fb3388452e4e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -46,6 +46,7 @@ class IntegralDeltaSuite extends SparkFunSuite { (input.tail, input.init).zipped.map { case (x: Int, y: Int) => (x - y).toLong case (x: Long, y: Long) => x - y + case other => fail(s"Unexpected input $other") } } @@ -116,7 +117,7 @@ class IntegralDeltaSuite extends SparkFunSuite { val row = new GenericInternalRow(1) val nullRow = new GenericInternalRow(1) nullRow.setNullAt(0) - input.map { value => + input.foreach { value => if (value == nullValue) { builder.appendFrom(nullRow, 0) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 80c76915e4c23..2d338ab92211e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.ConcurrentHashMap -import scala.collection.mutable - -import org.eclipse.jetty.util.ConcurrentHashSet import org.scalatest.concurrent.{Eventually, Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ @@ -48,7 +45,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { } test("trigger timing") { - val triggerTimes = new ConcurrentHashSet[Int] + val triggerTimes = ConcurrentHashMap.newKeySet[Int]() val clock = new StreamManualClock() @volatile var continueExecuting = true @volatile var clockIncrementInTrigger = 0L diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 635ea6fca649c..7db31f1f8f699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -382,10 +382,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before tasks.foreach { case t: TextSocketContinuousInputPartition => val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] - for (i <- 0 until numRecords / 2) { + for (_ <- 0 until numRecords / 2) { r.next() - assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) - .isInstanceOf[(String, Timestamp)]) + assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP).isInstanceOf[(_, _)]) } case _ => throw new IllegalStateException("Unexpected task type") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index e8710aeb40bd4..ddc5dbb148cb5 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -150,7 +150,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { def getPeakExecutionMemory(stageId: Int): Long = { val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables - .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .filter(_._2.name == Some(InternalAccumulator.PEAK_EXECUTION_MEMORY)) assert(peakMemoryAccumulator.size == 1) peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long] diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala index 5f9ea4d26790b..035b71a37a692 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.HiveUtils class HiveCliSessionStateSuite extends SparkFunSuite { def withSessionClear(f: () => Unit): Unit = { - try f finally SessionState.detachSession() + try f() finally SessionState.detachSession() } test("CliSessionState will be reused") { From 86cc907448f0102ad0c185e87fcc897d0a32707f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 15:11:42 -0600 Subject: [PATCH 077/145] This is a dummy commit to trigger ASF git sync From a09d5ba88680d07121ce94a4e68c3f42fc635f4f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 20 Nov 2018 09:27:46 +0800 Subject: [PATCH 078/145] [SPARK-26107][SQL] Extend ReplaceNullWithFalseInPredicate to support higher-order functions: ArrayExists, ArrayFilter, MapFilter ## What changes were proposed in this pull request? Extend the `ReplaceNullWithFalse` optimizer rule introduced in SPARK-25860 (https://github.com/apache/spark/pull/22857) to also support optimizing predicates in higher-order functions of `ArrayExists`, `ArrayFilter`, `MapFilter`. Also rename the rule to `ReplaceNullWithFalseInPredicate` to better reflect its intent. Example: ```sql select filter(a, e -> if(e is null, null, true)) as b from ( select array(null, 1, null, 3) as a) ``` The optimized logical plan: **Before**: ``` == Optimized Logical Plan == Project [filter([null,1,null,3], lambdafunction(if (isnull(lambda e#13)) null else true, lambda e#13, false)) AS b#9] +- OneRowRelation ``` **After**: ``` == Optimized Logical Plan == Project [filter([null,1,null,3], lambdafunction(if (isnull(lambda e#13)) false else true, lambda e#13, false)) AS b#9] +- OneRowRelation ``` ## How was this patch tested? Added new unit test cases to the `ReplaceNullWithFalseInPredicateSuite` (renamed from `ReplaceNullWithFalseSuite`). Closes #23079 from rednaxelafx/catalyst-master. 
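As a rough way to see the rewrite from a user session, the query from the description above can be inspected with `explain` (an illustrative `spark-shell` sketch rather than code from this patch; the exact optimized-plan text can vary by build):

```scala
// Build the example query from the description and inspect its optimized plan.
val df = spark.sql(
  """select filter(a, e -> if(e is null, null, true)) as b
    |from (select array(null, 1, null, 3) as a)""".stripMargin)

// With this rule, the lambda body in the optimized plan shows `false` where it
// previously showed `null` (see the Before/After plans above).
df.explain(true)

// The result itself is unchanged either way: null entries are filtered out,
// so the expected output is [1, 3], matching the new end-to-end test.
df.show()
```
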
Authored-by: Kris Mok Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 11 ++++- ...eplaceNullWithFalseInPredicateSuite.scala} | 48 +++++++++++++++++-- ...llWithFalseInPredicateEndToEndSuite.scala} | 45 ++++++++++++++++- 4 files changed, 98 insertions(+), 8 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{ReplaceNullWithFalseSuite.scala => ReplaceNullWithFalseInPredicateSuite.scala} (87%) rename sql/core/src/test/scala/org/apache/spark/sql/{ReplaceNullWithFalseEndToEndSuite.scala => ReplaceNullWithFalseInPredicateEndToEndSuite.scala} (63%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a330a84a3a24f..8d251eeab8484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -84,7 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - ReplaceNullWithFalse, + ReplaceNullWithFalseInPredicate, PruneFilters, EliminateSorts, SimplifyCasts, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 2b29b49d00ab9..354efd883f814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -755,7 +755,7 @@ object CombineConcats extends Rule[LogicalPlan] { * * As a result, many unnecessary computations can be removed in the query optimization phase. 
*/ -object ReplaceNullWithFalse extends Rule[LogicalPlan] { +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) @@ -767,6 +767,15 @@ object ReplaceNullWithFalse extends Rule[LogicalPlan] { replaceNullWithFalse(cond) -> value } cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala similarity index 87% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index c6b5d0ec96776..3a9e6cae0fd87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.{BooleanType, IntegerType} -class ReplaceNullWithFalseSuite extends PlanTest { +class ReplaceNullWithFalseInPredicateSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -36,10 +36,11 @@ class ReplaceNullWithFalseSuite extends PlanTest { ConstantFolding, BooleanSimplification, SimplifyConditionals, - ReplaceNullWithFalse) :: Nil + ReplaceNullWithFalseInPredicate) :: Nil } - private val testRelation = LocalRelation('i.int, 'b.boolean) + private val testRelation = + LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType)) private val anotherTestRelation = LocalRelation('d.int) test("replace null inside filter and join conditions") { @@ -298,6 +299,26 @@ class ReplaceNullWithFalseSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + test("replace nulls in lambda function of ArrayFilter") { + testHigherOrderFunc('a, ArrayFilter, Seq('e)) + } + + test("replace nulls in lambda function of ArrayExists") { + testHigherOrderFunc('a, 
ArrayExists, Seq('e)) + } + + test("replace nulls in lambda function of MapFilter") { + testHigherOrderFunc('m, MapFilter, Seq('k, 'v)) + } + + test("inability to replace nulls in arbitrary higher-order function") { + val lambdaFunc = LambdaFunction( + function = If('e > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression]('e)) + val column = ArrayTransform('a, lambdaFunc) + testProjection(originalExpr = column, expectedExpr = column) + } + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { test((rel, exp) => rel.where(exp), originalCond, expectedCond) } @@ -310,6 +331,25 @@ class ReplaceNullWithFalseSuite extends PlanTest { test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) } + private def testHigherOrderFunc( + argument: Expression, + createExpr: (Expression, Expression) => Expression, + lambdaArgs: Seq[NamedExpression]): Unit = { + val condArg = lambdaArgs.last + // the lambda body is: if(arg > 0, null, true) + val cond = GreaterThan(condArg, Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = lambdaArgs) + // the optimized lambda body is: if(arg > 0, false, true) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = lambdaArgs) + testProjection( + originalExpr = createExpr(argument, lambda1) as 'x, + expectedExpr = createExpr(argument, lambda2) as 'x) + } + private def test( func: (LogicalPlan, Expression) => LogicalPlan, originalExpr: Expression, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala similarity index 63% rename from sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index fc6ecc4e032f6..0f84b0c961a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If} +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, Literal} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{lit, when} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.BooleanType -class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext { +class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { @@ -68,4 +69,44 @@ class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext case p => fail(s"$p is not LocalTableScanExec") } } + + test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { + def assertNoLiteralNullInPlan(df: DataFrame): Unit = { + df.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(_.find { + case Literal(null, BooleanType) => true + case _ => false + }.isEmpty)) + } + } + + withTable("t1", "t2") { + // to test ArrayFilter and ArrayExists + spark.sql("select array(null, 1, null, 3) as a") + .write.saveAsTable("t1") + // to test MapFilter + spark.sql(""" + 
select map_from_entries(arrays_zip(a, transform(a, e -> if(mod(e, 2) = 0, null, e)))) as m + from (select array(0, 1, 2, 3) as a) + """).write.saveAsTable("t2") + + val df1 = spark.table("t1") + val df2 = spark.table("t2") + + // ArrayExists + val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is null, null, true))") + checkAnswer(q1, Row(true) :: Nil) + assertNoLiteralNullInPlan(q1) + + // ArrayFilter + val q2 = df1.selectExpr("FILTER(a, e -> IF(e is null, null, true))") + checkAnswer(q2, Row(Seq[Any](1, 3)) :: Nil) + assertNoLiteralNullInPlan(q2) + + // MapFilter + val q3 = df2.selectExpr("MAP_FILTER(m, (k, v) -> IF(v is null, null, true))") + checkAnswer(q3, Row(Map[Any, Any](1 -> 1, 3 -> 3))) + assertNoLiteralNullInPlan(q3) + } + } } From a00aaf649cb5a14648102b2980ce21393804f2c7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 20 Nov 2018 08:27:57 -0600 Subject: [PATCH 079/145] [MINOR][YARN] Make memLimitExceededLogMessage more clean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Current `memLimitExceededLogMessage`: It‘s not very clear, because physical memory exceeds but suggestion contains virtual memory config. This pr makes it more clear and replace deprecated config: ```spark.yarn.executor.memoryOverhead```. ## How was this patch tested? manual tests Closes #23030 from wangyum/EXECUTOR_MEMORY_OVERHEAD. Authored-by: Yuming Wang Signed-off-by: Sean Owen --- .../spark/deploy/yarn/YarnAllocator.scala | 33 ++++++++----------- .../deploy/yarn/YarnAllocatorSuite.scala | 12 ------- 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebdcf45603cea..9497530805c1a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -20,7 +20,6 @@ package org.apache.spark.deploy.yarn import java.util.Collections import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger -import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable @@ -598,13 +597,21 @@ private[yarn] class YarnAllocator( (false, s"Container ${containerId}${onHostStr} was preempted.") // Should probably still count memory exceeded exit codes towards task failures case VMEM_EXCEEDED_EXIT_CODE => - (true, memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) + val vmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX virtual memory used".r + val diag = vmemExceededPattern.findFirstIn(completedContainer.getDiagnostics) + .map(_.concat(".")).getOrElse("") + val message = "Container killed by YARN for exceeding virtual memory limits. " + + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key} or boosting " + + s"${YarnConfiguration.NM_VMEM_PMEM_RATIO} or disabling " + + s"${YarnConfiguration.NM_VMEM_CHECK_ENABLED} because of YARN-4714." 
+ (true, message) case PMEM_EXCEEDED_EXIT_CODE => - (true, memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) + val pmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX physical memory used".r + val diag = pmemExceededPattern.findFirstIn(completedContainer.getDiagnostics) + .map(_.concat(".")).getOrElse("") + val message = "Container killed by YARN for exceeding physical memory limits. " + + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key}." + (true, message) case _ => // all the failures which not covered above, like: // disk failure, kill by app master or resource manager, ... @@ -735,18 +742,6 @@ private[yarn] class YarnAllocator( private object YarnAllocator { val MEM_REGEX = "[0-9.]+ [KMG]B" - val PMEM_EXCEEDED_PATTERN = - Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") - val VMEM_EXCEEDED_PATTERN = - Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") val VMEM_EXCEEDED_EXIT_CODE = -103 val PMEM_EXCEEDED_EXIT_CODE = -104 - - def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { - val matcher = pattern.matcher(diagnostics) - val diag = if (matcher.find()) " " + matcher.group() + "." else "" - s"Container killed by YARN for exceeding memory limits. $diag " + - "Consider boosting spark.yarn.executor.memoryOverhead or " + - "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." - } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index c3070de3d17cf..b61e7df4420ef 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -29,7 +29,6 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.rpc.RpcEndpointRef @@ -376,17 +375,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } - test("memory exceeded diagnostic regexes") { - val diagnostics = - "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + - "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " + - "5.8 GB of 4.2 GB virtual memory used. Killing container." - val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN) - val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN) - assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) - assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) - } - test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) From c34c42234f308872ebe9c7cdaee32000c0726eea Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 20 Nov 2018 08:29:59 -0600 Subject: [PATCH 080/145] [SPARK-26076][BUILD][MINOR] Revise ambiguous error message from load-spark-env.sh ## What changes were proposed in this pull request? When I try to run scripts (e.g. 
`start-master.sh`/`start-history-server.sh ` in latest master, I got such error: ``` Presence of build for multiple Scala versions detected. Either clean one of them or, export SPARK_SCALA_VERSION in spark-env.sh. ``` The error message is quite confusing. Without reading `load-spark-env.sh`, I didn't know which directory to remove, or where to find and edit the `spark-evn.sh`. This PR is to make the error message more clear. Also change the script for less maintenance when we add or drop Scala versions in the future. As now with https://github.com/apache/spark/pull/22967, we can revise the error message as following(in my local setup): ``` Presence of build for multiple Scala versions detected (/Users/gengliangwang/IdeaProjects/spark/assembly/target/scala-2.12 and /Users/gengliangwang/IdeaProjects/spark/assembly/target/scala-2.11). Remove one of them or, export SPARK_SCALA_VERSION=2.12 in /Users/gengliangwang/IdeaProjects/spark/conf/spark-env.sh. Visit https://spark.apache.org/docs/latest/configuration.html#environment-variables for more details about setting environment variables in spark-env.sh. ``` ## How was this patch tested? Manual test Closes #23049 from gengliangwang/reviseEnvScript. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- bin/load-spark-env.sh | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 0b5006dbd63ac..0ada5d8d0fc1d 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -26,15 +26,17 @@ if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home fi +SPARK_ENV_SH="spark-env.sh" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}"/conf}" - if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then + SPARK_ENV_SH="${SPARK_CONF_DIR}/${SPARK_ENV_SH}" + if [[ -f "${SPARK_ENV_SH}" ]]; then # Promote all variable declarations to environment (exported) variables set -a - . "${SPARK_CONF_DIR}/spark-env.sh" + . ${SPARK_ENV_SH} set +a fi fi @@ -42,19 +44,22 @@ fi # Setting SPARK_SCALA_VERSION if not already set. if [ -z "$SPARK_SCALA_VERSION" ]; then + SCALA_VERSION_1=2.12 + SCALA_VERSION_2=2.11 - ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" - ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.12" - - if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then - echo -e "Presence of build for multiple Scala versions detected." 1>&2 - echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION in spark-env.sh.' 1>&2 + ASSEMBLY_DIR_1="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_1}" + ASSEMBLY_DIR_2="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_2}" + ENV_VARIABLE_DOC="https://spark.apache.org/docs/latest/configuration.html#environment-variables" + if [[ -d "$ASSEMBLY_DIR_1" && -d "$ASSEMBLY_DIR_2" ]]; then + echo "Presence of build for multiple Scala versions detected ($ASSEMBLY_DIR_1 and $ASSEMBLY_DIR_2)." 1>&2 + echo "Remove one of them or, export SPARK_SCALA_VERSION=$SCALA_VERSION_1 in ${SPARK_ENV_SH}." 1>&2 + echo "Visit ${ENV_VARIABLE_DOC} for more details about setting environment variables in spark-env.sh." 
1>&2 exit 1 fi - if [ -d "$ASSEMBLY_DIR2" ]; then - export SPARK_SCALA_VERSION="2.11" + if [[ -d "$ASSEMBLY_DIR_1" ]]; then + export SPARK_SCALA_VERSION=${SCALA_VERSION_1} else - export SPARK_SCALA_VERSION="2.12" + export SPARK_SCALA_VERSION=${SCALA_VERSION_2} fi fi From ab61ddb34d58ab5701191c8fd3a24a62f6ebf37b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 20 Nov 2018 08:56:22 -0600 Subject: [PATCH 081/145] [SPARK-26118][WEB UI] Introducing spark.ui.requestHeaderSize for setting HTTP requestHeaderSize MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Introducing spark.ui.requestHeaderSize for configuring Jetty's HTTP requestHeaderSize. This way long authorization field does not lead to HTTP 413. ## How was this patch tested? Manually with curl (which version must be at least 7.55). With the original default value (8k limit): ```bash # Starting history server with default requestHeaderSize $ ./sbin/start-history-server.sh starting org.apache.spark.deploy.history.HistoryServer, logging to /Users/attilapiros/github/spark/logs/spark-attilapiros-org.apache.spark.deploy.history.HistoryServer-1-apiros-MBP.lan.out # Creating huge header $ echo -n "X-Custom-Header: " > cookie $ printf 'A%.0s' {1..9500} >> cookie # HTTP GET with huge header fails with 431 $ curl -H cookie http://458apiros-MBP.lan:18080/

    Bad Message 431
    reason: Request Header Fields Too Large
    # The log contains the error $ tail -1 /Users/attilapiros/github/spark/logs/spark-attilapiros-org.apache.spark.deploy.history.HistoryServer-1-apiros-MBP.lan.out 18/11/19 21:24:28 WARN HttpParser: Header is too large 8193>8192 ``` After: ```bash # Creating the history properties file with the increased requestHeaderSize $ echo spark.ui.requestHeaderSize=10000 > history.properties # Starting Spark History Server with the settings $ ./sbin/start-history-server.sh --properties-file history.properties starting org.apache.spark.deploy.history.HistoryServer, logging to /Users/attilapiros/github/spark/logs/spark-attilapiros-org.apache.spark.deploy.history.HistoryServer-1-apiros-MBP.lan.out # HTTP GET with huge header gives back HTML5 (I have added here only just a part of the response) $ curl -H cookie http://458apiros-MBP.lan:18080/ ... History Server ... ``` Closes #23090 from attilapiros/JettyHeaderSize. Authored-by: “attilapiros” Signed-off-by: Imran Rashid --- .../scala/org/apache/spark/internal/config/package.scala | 6 ++++++ core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 6 ++++-- docs/configuration.md | 8 ++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ab2b872c5551e..9cc48f6375003 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -570,6 +570,12 @@ package object config { .stringConf .createOptional + private[spark] val UI_REQUEST_HEADER_SIZE = + ConfigBuilder("spark.ui.requestHeaderSize") + .doc("Value for HTTP request header size in bytes.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("8k") + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") .doc("Class names of listeners to add to SparkContext during initialization.") .stringConf diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 52a955111231a..316af9b79d286 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -356,13 +356,15 @@ private[spark] object JettyUtils extends Logging { (connector, connector.getLocalPort()) } + val httpConfig = new HttpConfiguration() + httpConfig.setRequestHeaderSize(conf.get(UI_REQUEST_HEADER_SIZE).toInt) // If SSL is configured, create the secure connector first. val securePort = sslOptions.createJettySslContextFactory().map { factory => val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0) val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName val connectionFactories = AbstractConnectionFactory.getFactories(factory, - new HttpConnectionFactory()) + new HttpConnectionFactory(httpConfig)) def sslConnect(currentPort: Int): (ServerConnector, Int) = { newConnector(connectionFactories, currentPort) @@ -377,7 +379,7 @@ private[spark] object JettyUtils extends Logging { // Bind the HTTP port. 
def httpConnect(currentPort: Int): (ServerConnector, Int) = { - newConnector(Array(new HttpConnectionFactory()), currentPort) + newConnector(Array(new HttpConnectionFactory(httpConfig)), currentPort) } val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect, diff --git a/docs/configuration.md b/docs/configuration.md index 2915fb5fa9197..04210d855b110 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -973,6 +973,14 @@ Apart from these, the following properties are also available, and may be useful
    spark.com.test.filter1.param.name2=bar + + spark.ui.requestHeaderSize + 8k + + The maximum allowed size for a HTTP request header, in bytes unless otherwise specified. + This setting applies for the Spark History Server too. + + ### Compression and Serialization From db136d360e54e13f1d7071a0428964a202cf7e31 Mon Sep 17 00:00:00 2001 From: Simeon Simeonov Date: Tue, 20 Nov 2018 21:29:56 +0100 Subject: [PATCH 082/145] [SPARK-26084][SQL] Fixes unresolved AggregateExpression.references exception ## What changes were proposed in this pull request? This PR fixes an exception in `AggregateExpression.references` called on unresolved expressions. It implements the solution proposed in [SPARK-26084](https://issues.apache.org/jira/browse/SPARK-26084), a minor refactoring that removes the unnecessary dependence on `AttributeSet.toSeq`, which requires expression IDs and, therefore, can only execute successfully for resolved expressions. The refactored implementation is both simpler and faster, eliminating the conversion of a `Set` to a `Seq` and back to `Set`. ## How was this patch tested? Added a new test based on the failing case in [SPARK-26084](https://issues.apache.org/jira/browse/SPARK-26084). hvanhovell Closes #23075 from ssimeonov/ss_SPARK-26084. Authored-by: Simeon Simeonov Signed-off-by: Herman van Hovell --- .../expressions/aggregate/interfaces.scala | 8 ++--- .../aggregate/AggregateExpressionSuite.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e1d16a2cd38b0..56c2ee6b53fe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -128,12 +128,10 @@ case class AggregateExpression( override def nullable: Boolean = aggregateFunction.nullable override def references: AttributeSet = { - val childReferences = mode match { - case Partial | Complete => aggregateFunction.references.toSeq - case PartialMerge | Final => aggregateFunction.aggBufferAttributes + mode match { + case Partial | Complete => aggregateFunction.references + case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes) } - - AttributeSet(childReferences) } override def toString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala new file mode 100644 index 0000000000000..8e9c9972071ad --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeSet} + +class AggregateExpressionSuite extends SparkFunSuite { + + test("test references from unresolved aggregate functions") { + val x = UnresolvedAttribute("x") + val y = UnresolvedAttribute("y") + val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete, isDistinct = false).references + val expected = AttributeSet(x :: y :: Nil) + assert(expected == actual, s"Expected: $expected. Actual: $actual") + } + +} From 42c48387c047d96154bcfeb95fcb816a43e60d7c Mon Sep 17 00:00:00 2001 From: shane knapp Date: Tue, 20 Nov 2018 12:38:40 -0800 Subject: [PATCH 083/145] [BUILD] refactor dev/lint-python in to something readable ## What changes were proposed in this pull request? `dev/lint-python` is a mess of nearly unreadable bash. i would like to fix that as best as i can. ## How was this patch tested? the build system will test this. Closes #22994 from shaneknapp/lint-python-refactor. Authored-by: shane knapp Signed-off-by: shane knapp --- dev/lint-python | 359 +++++++++++++++++++++++++++++------------------- 1 file changed, 220 insertions(+), 139 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index 27d87f6b56680..06816932e754a 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -1,5 +1,4 @@ #!/usr/bin/env bash - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -16,160 +15,242 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# define test binaries + versions +PYDOCSTYLE_BUILD="pydocstyle" +MINIMUM_PYDOCSTYLE="3.0.0" -SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -# Exclude auto-generated configuration file. -PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" -DOC_PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" | grep -vF 'functions.py' )" -PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" -PYDOCSTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pydocstyle-report.txt" -PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" -PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" - -PYDOCSTYLEBUILD="pydocstyle" -MINIMUM_PYDOCSTYLEVERSION="3.0.0" - -FLAKE8BUILD="flake8" +FLAKE8_BUILD="flake8" MINIMUM_FLAKE8="3.5.0" -SPHINXBUILD=${SPHINXBUILD:=sphinx-build} -SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" +PYCODESTYLE_BUILD="pycodestyle" +MINIMUM_PYCODESTYLE="2.4.0" -cd "$SPARK_ROOT_DIR" +SPHINX_BUILD="sphinx-build" -# compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" -compile_status="${PIPESTATUS[0]}" +function compile_python_test { + local COMPILE_STATUS= + local COMPILE_REPORT= + + if [[ ! "$1" ]]; then + echo "No python files found! Something is very wrong -- exiting." 
+ exit 1; + fi -# Get pycodestyle at runtime so that we don't rely on it being installed on the build server. -# See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -# Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. -PYCODESTYLE_VERSION="2.4.0" -PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" -PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" + # compileall: https://docs.python.org/2/library/compileall.html + echo "starting python compilation test..." + COMPILE_REPORT=$( (python -B -mcompileall -q -l $1) 2>&1) + COMPILE_STATUS=$? + + if [ $COMPILE_STATUS -ne 0 ]; then + echo "Python compilation failed with the following errors:" + echo "$COMPILE_REPORT" + echo "$COMPILE_STATUS" + exit "$COMPILE_STATUS" + else + echo "python compilation succeeded." + echo + fi +} -if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then - curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" - curl_status="$?" +function pycodestyle_test { + local PYCODESTYLE_STATUS= + local PYCODESTYLE_REPORT= + local RUN_LOCAL_PYCODESTYLE= + local VERSION= + local EXPECTED_PYCODESTYLE= + local PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$MINIMUM_PYCODESTYLE.py" + local PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$MINIMUM_PYCODESTYLE/pycodestyle.py" - if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pycodestyle.py from \"$PYCODESTYLE_SCRIPT_REMOTE_PATH\"." - exit "$curl_status" + if [[ ! "$1" ]]; then + echo "No python files found! Something is very wrong -- exiting." + exit 1; fi -fi - -# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should -# be set to the directory. -# dev/pylint should be appended to the PATH variable as well. -# Jenkins by default installs the pylint3 version, so for now this just checks the code quality -# of python3. -export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" -export "PYLINT_HOME=$PYTHONPATH" -export "PATH=$PYTHONPATH:$PATH" - -# There is no need to write this output to a file -# first, but we do so so that the check status can -# be output before the report, like with the -# scalastyle and RAT checks. -python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" -pycodestyle_status="${PIPESTATUS[0]}" - -if [ "$compile_status" -eq 0 -a "$pycodestyle_status" -eq 0 ]; then - lint_status=0 -else - lint_status=1 -fi - -if [ "$lint_status" -ne 0 ]; then - echo "pycodestyle checks failed." - cat "$PYCODESTYLE_REPORT_PATH" - rm "$PYCODESTYLE_REPORT_PATH" - exit "$lint_status" -else - echo "pycodestyle checks passed." - rm "$PYCODESTYLE_REPORT_PATH" -fi - -# Check by flake8 -if hash "$FLAKE8BUILD" 2> /dev/null; then - FLAKE8VERSION="$( $FLAKE8BUILD --version 2> /dev/null )" - VERSION=($FLAKE8VERSION) - IS_EXPECTED_FLAKE8=$(python -c 'from distutils.version import LooseVersion; \ -print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_FLAKE8'"""))' 2> /dev/null) - if [[ "$IS_EXPECTED_FLAKE8" == "True" ]]; then - # stop the build if there are Python syntax errors or undefined names - $FLAKE8BUILD . 
--count --select=E901,E999,F821,F822,F823 --max-line-length=100 --show-source --statistics - flake8_status="${PIPESTATUS[0]}" - - if [ "$flake8_status" -eq 0 ]; then - lint_status=0 - else - lint_status=1 + + # check for locally installed pycodestyle & version + RUN_LOCAL_PYCODESTYLE="False" + if hash "$PYCODESTYLE_BUILD" 2> /dev/null; then + VERSION=$( $PYCODESTYLE_BUILD --version 2> /dev/null) + EXPECTED_PYCODESTYLE=$( (python -c 'from distutils.version import LooseVersion; + print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_PYCODESTYLE'"""))')\ + 2> /dev/null) + + if [ "$EXPECTED_PYCODESTYLE" == "True" ]; then + RUN_LOCAL_PYCODESTYLE="True" fi + fi - if [ "$lint_status" -ne 0 ]; then - echo "flake8 checks failed." - exit "$lint_status" - else - echo "flake8 checks passed." + # download the right version or run locally + if [ $RUN_LOCAL_PYCODESTYLE == "False" ]; then + # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. + # See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 + # Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. + echo "downloading pycodestyle from $PYCODESTYLE_SCRIPT_REMOTE_PATH..." + if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then + curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" + local curl_status="$?" + + if [ "$curl_status" -ne 0 ]; then + echo "Failed to download pycodestyle.py from $PYCODESTYLE_SCRIPT_REMOTE_PATH" + exit "$curl_status" + fi fi + + echo "starting pycodestyle test..." + PYCODESTYLE_REPORT=$( (python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $1) 2>&1) + PYCODESTYLE_STATUS=$? + else + # we have the right version installed, so run locally + echo "starting pycodestyle test..." + PYCODESTYLE_REPORT=$( ($PYCODESTYLE_BUILD --config=dev/tox.ini $1) 2>&1) + PYCODESTYLE_STATUS=$? + fi + + if [ $PYCODESTYLE_STATUS -ne 0 ]; then + echo "pycodestyle checks failed:" + echo "$PYCODESTYLE_REPORT" + exit "$PYCODESTYLE_STATUS" else - echo "The flake8 version needs to be "$MINIMUM_FLAKE8" at latest. Your current version is '"$FLAKE8VERSION"'." + echo "pycodestyle checks passed." + echo + fi +} + +function flake8_test { + local FLAKE8_VERSION= + local VERSION= + local EXPECTED_FLAKE8= + local FLAKE8_REPORT= + local FLAKE8_STATUS= + + if ! hash "$FLAKE8_BUILD" 2> /dev/null; then + echo "The flake8 command was not found." echo "flake8 checks failed." exit 1 fi -else - echo >&2 "The flake8 command was not found." - echo "flake8 checks failed." - exit 1 -fi - -# Check python document style, skip check if pydocstyle is not installed. -if hash "$PYDOCSTYLEBUILD" 2> /dev/null; then - PYDOCSTYLEVERSION="$( $PYDOCSTYLEBUILD --version 2> /dev/null )" - IS_EXPECTED_PYDOCSTYLEVERSION=$(python -c 'from distutils.version import LooseVersion; \ -print(LooseVersion("""'$PYDOCSTYLEVERSION'""") >= LooseVersion("""'$MINIMUM_PYDOCSTYLEVERSION'"""))') - if [[ "$IS_EXPECTED_PYDOCSTYLEVERSION" == "True" ]]; then - $PYDOCSTYLEBUILD --config=dev/tox.ini $DOC_PATHS_TO_CHECK >> "$PYDOCSTYLE_REPORT_PATH" - pydocstyle_status="${PIPESTATUS[0]}" - - if [ "$compile_status" -eq 0 -a "$pydocstyle_status" -eq 0 ]; then - echo "pydocstyle checks passed." - rm "$PYDOCSTYLE_REPORT_PATH" - else - echo "pydocstyle checks failed." 
- cat "$PYDOCSTYLE_REPORT_PATH" - rm "$PYDOCSTYLE_REPORT_PATH" - exit 1 - fi + FLAKE8_VERSION="$($FLAKE8_BUILD --version 2> /dev/null)" + VERSION=($FLAKE8_VERSION) + EXPECTED_FLAKE8=$( (python -c 'from distutils.version import LooseVersion; + print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_FLAKE8'"""))') \ + 2> /dev/null) + + if [[ "$EXPECTED_FLAKE8" == "False" ]]; then + echo "\ +The minimum flake8 version needs to be $MINIMUM_FLAKE8. Your current version is $FLAKE8_VERSION + +flake8 checks failed." + exit 1 + fi + + echo "starting $FLAKE8_BUILD test..." + FLAKE8_REPORT=$( ($FLAKE8_BUILD . --count --select=E901,E999,F821,F822,F823 \ + --max-line-length=100 --show-source --statistics) 2>&1) + FLAKE8_STATUS=$? + + if [ "$FLAKE8_STATUS" -ne 0 ]; then + echo "flake8 checks failed:" + echo "$FLAKE8_REPORT" + echo "$FLAKE8_STATUS" + exit "$FLAKE8_STATUS" else - echo "The pydocstyle version needs to be "$MINIMUM_PYDOCSTYLEVERSION" at latest. Your current version is "$PYDOCSTYLEVERSION". Skipping pydoc checks for now." + echo "flake8 checks passed." + echo fi -else - echo >&2 "The pydocstyle command was not found. Skipping pydoc checks for now" -fi - -# Check that the documentation builds acceptably, skip check if sphinx is not installed. -if hash "$SPHINXBUILD" 2> /dev/null; then - cd python/docs - make clean - # Treat warnings as errors so we stop correctly - SPHINXOPTS="-a -W" make html &> "$SPHINX_REPORT_PATH" || lint_status=1 - if [ "$lint_status" -ne 0 ]; then - echo "pydoc checks failed." - cat "$SPHINX_REPORT_PATH" - echo "re-running make html to print full warning list" - make clean - SPHINXOPTS="-a" make html - rm "$SPHINX_REPORT_PATH" - exit "$lint_status" - else - echo "pydoc checks passed." - rm "$SPHINX_REPORT_PATH" - fi - cd ../.. -else - echo >&2 "The $SPHINXBUILD command was not found. Skipping pydoc checks for now" -fi +} + +function pydocstyle_test { + local PYDOCSTYLE_REPORT= + local PYDOCSTYLE_STATUS= + local PYDOCSTYLE_VERSION= + local EXPECTED_PYDOCSTYLE= + + # Exclude auto-generated configuration file. + local DOC_PATHS_TO_CHECK="$( cd "${SPARK_ROOT_DIR}" && find . -name "*.py" | grep -vF 'functions.py' )" + + # Check python document style, skip check if pydocstyle is not installed. + if ! hash "$PYDOCSTYLE_BUILD" 2> /dev/null; then + echo "The pydocstyle command was not found. Skipping pydocstyle checks for now." + echo + return + fi + + PYDOCSTYLE_VERSION="$($PYDOCSTYLEBUILD --version 2> /dev/null)" + EXPECTED_PYDOCSTYLE=$(python -c 'from distutils.version import LooseVersion; \ + print(LooseVersion("""'$PYDOCSTYLE_VERSION'""") >= LooseVersion("""'$MINIMUM_PYDOCSTYLE'"""))' \ + 2> /dev/null) + + if [[ "$EXPECTED_PYDOCSTYLE" == "False" ]]; then + echo "\ +The minimum version of pydocstyle needs to be $MINIMUM_PYDOCSTYLE. +Your current version is $PYDOCSTYLE_VERSION. +Skipping pydocstyle checks for now." + echo + return + fi + + echo "starting $PYDOCSTYLE_BUILD test..." + PYDOCSTYLE_REPORT=$( ($PYDOCSTYLE_BUILD --config=dev/tox.ini $DOC_PATHS_TO_CHECK) 2>&1) + PYDOCSTYLE_STATUS=$? + + if [ "$PYDOCSTYLE_STATUS" -ne 0 ]; then + echo "pydocstyle checks failed:" + echo "$PYDOCSTYLE_REPORT" + exit "$PYDOCSTYLE_STATUS" + else + echo "pydocstyle checks passed." + echo + fi +} + +function sphinx_test { + local SPHINX_REPORT= + local SPHINX_STATUS= + + # Check that the documentation builds acceptably, skip check if sphinx is not installed. + if ! hash "$SPHINX_BUILD" 2> /dev/null; then + echo "The $SPHINX_BUILD command was not found. 
Skipping pydoc checks for now." + echo + return + fi + + echo "starting $SPHINX_BUILD tests..." + pushd python/docs &> /dev/null + make clean &> /dev/null + # Treat warnings as errors so we stop correctly + SPHINX_REPORT=$( (SPHINXOPTS="-a -W" make html) 2>&1) + SPHINX_STATUS=$? + + if [ "$SPHINX_STATUS" -ne 0 ]; then + echo "$SPHINX_BUILD checks failed:" + echo "$SPHINX_REPORT" + echo + echo "re-running make html to print full warning list:" + make clean &> /dev/null + SPHINX_REPORT=$( (SPHINXOPTS="-a" make html) 2>&1) + echo "$SPHINX_REPORT" + exit "$SPHINX_STATUS" + else + echo "$SPHINX_BUILD checks passed." + echo + fi + + popd &> /dev/null +} + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname "${SCRIPT_DIR}")" + +pushd "$SPARK_ROOT_DIR" &> /dev/null + +PYTHON_SOURCE="$(find . -name "*.py")" + +compile_python_test "$PYTHON_SOURCE" +pycodestyle_test "$PYTHON_SOURCE" +flake8_test +pydocstyle_test +sphinx_test + +echo +echo "all lint-python tests passed!" + +popd &> /dev/null From 23bcd6ce458f1e49f307c89ca2794dc9a173077c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 20 Nov 2018 18:03:54 -0600 Subject: [PATCH 084/145] [SPARK-26043][HOTFIX] Hotfix a change to SparkHadoopUtil that doesn't work in 2.11 ## What changes were proposed in this pull request? Hotfix a change to SparkHadoopUtil that doesn't work in 2.11 ## How was this patch tested? Existing tests. Closes #23097 from srowen/SPARK-26043.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 217e5145f1c56..7bb2a419107d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Date, Locale} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.collection.immutable.Map @@ -269,10 +269,11 @@ private[spark] class SparkHadoopUtil extends Logging { name.startsWith(prefix) && !name.endsWith(exclusionSuffix) } }) - Arrays.sort(fileStatuses, - (o1: FileStatus, o2: FileStatus) => { + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { Longs.compare(o1.getModificationTime, o2.getModificationTime) - }) + } + }) fileStatuses } catch { case NonFatal(e) => From 47851056c20c5d981b1ca66bac3f00c19a882727 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 20 Nov 2018 18:05:39 -0600 Subject: [PATCH 085/145] [SPARK-26124][BUILD] Update plugins to latest versions ## What changes were proposed in this pull request? Update many plugins we use to the latest version, especially MiMa, which entails excluding some new errors on old changes. ## How was this patch tested? N/A Closes #23087 from srowen/Plugins. 
Authored-by: Sean Owen Signed-off-by: Sean Owen --- pom.xml | 40 +++++++++++++++++++++++--------------- project/MimaExcludes.scala | 10 +++++++++- project/plugins.sbt | 14 ++++++------- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/pom.xml b/pom.xml index 9130773cb5094..08a29d2d52310 100644 --- a/pom.xml +++ b/pom.xml @@ -1977,7 +1977,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 3.0.0-M1 + 3.0.0-M2 enforce-versions @@ -2077,7 +2077,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.7.0 + 3.8.0 ${java.version} ${java.version} @@ -2094,7 +2094,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.22.0 + 3.0.0-M1 @@ -2148,7 +2148,7 @@ org.scalatest scalatest-maven-plugin - 1.0 + 2.0.0 ${project.build.directory}/surefire-reports @@ -2195,7 +2195,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.0.2 + 3.1.0 org.apache.maven.plugins @@ -2222,7 +2222,7 @@ org.apache.maven.plugins maven-clean-plugin - 3.0.0 + 3.1.0 @@ -2240,9 +2240,12 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.0.0-M1 + 3.0.1 - -Xdoclint:all -Xdoclint:-missing + + -Xdoclint:all + -Xdoclint:-missing + example @@ -2293,7 +2296,7 @@ org.apache.maven.plugins maven-shade-plugin - 3.2.0 + 3.2.1 org.ow2.asm @@ -2310,12 +2313,12 @@ org.apache.maven.plugins maven-install-plugin - 2.5.2 + 3.0.0-M1 org.apache.maven.plugins maven-deploy-plugin - 2.8.2 + 3.0.0-M1 org.apache.maven.plugins @@ -2361,7 +2364,7 @@ org.apache.maven.plugins maven-jar-plugin - [2.6,) + 3.1.0 test-jar @@ -2518,12 +2521,17 @@ org.apache.maven.plugins maven-checkstyle-plugin - 2.17 + 3.0.0 false true - ${basedir}/src/main/java,${basedir}/src/main/scala - ${basedir}/src/test/java + + ${basedir}/src/main/java + ${basedir}/src/main/scala + + + ${basedir}/src/test/java + dev/checkstyle.xml ${basedir}/target/checkstyle-output.xml ${project.build.sourceEncoding} @@ -2533,7 +2541,7 @@ com.puppycrawl.tools checkstyle - 8.2 + 8.14 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e35e74aa33045..b750535e8a70b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,7 +36,15 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( - // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 + // [SPARK-26124] Update plugins, including MiMa + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns.build"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.planInputPartitions"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.planInputPartitions"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters.build"), + + // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"), 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), diff --git a/project/plugins.sbt b/project/plugins.sbt index ffbd417b0f145..c9354735a62f5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,7 +1,7 @@ addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") // sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. -libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.14" // checkstyle uses guava 23.0. libraryDependencies += "com.google.guava" % "guava" % "23.0" @@ -9,13 +9,13 @@ libraryDependencies += "com.google.guava" % "guava" % "23.0" // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.3") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4") -addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.0") +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.17") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.3.0") // sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6 addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") @@ -28,12 +28,12 @@ addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") addSbtPlugin("io.spray" % "sbt-revolver" % "0.9.1") -libraryDependencies += "org.ow2.asm" % "asm" % "5.1" +libraryDependencies += "org.ow2.asm" % "asm" % "7.0" -libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1" +libraryDependencies += "org.ow2.asm" % "asm-commons" % "7.0" // sbt 1.0.0 support: https://github.com/ihji/sbt-antlr4/issues/14 -addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11") +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.12") // Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues // related to test-jar dependencies (https://github.com/sbt/sbt-pom-reader/pull/14). The source for From 2df34db586bec379e40b5cf30021f5b7a2d79271 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 21 Nov 2018 09:29:22 +0800 Subject: [PATCH 086/145] [SPARK-26122][SQL] Support encoding for multiLine in CSV datasource ## What changes were proposed in this pull request? In the PR, I propose to pass the CSV option `encoding`/`charset` to `uniVocity` parser to allow parsing CSV files in different encodings when `multiLine` is enabled. The value of the option is passed to the `beginParsing` method of `CSVParser`. ## How was this patch tested? Added new test to `CSVSuite` for different encodings and enabled/disabled header. Closes #23091 from MaxGekk/csv-miltiline-encoding. 
Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../sql/catalyst/csv/UnivocityParser.scala | 12 ++++++----- .../datasources/csv/CSVDataSource.scala | 6 ++++-- .../execution/datasources/csv/CSVSuite.scala | 21 +++++++++++++++++++ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 46ed58ed92830..ed196935e357f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -271,11 +271,12 @@ private[sql] object UnivocityParser { def tokenizeStream( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser): Iterator[Array[String]] = { + tokenizer: CsvParser, + encoding: String): Iterator[Array[String]] = { val handleHeader: () => Unit = () => if (shouldDropHeader) tokenizer.parseNext - convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens) + convertStream(inputStream, tokenizer, handleHeader, encoding)(tokens => tokens) } /** @@ -297,7 +298,7 @@ private[sql] object UnivocityParser { val handleHeader: () => Unit = () => headerChecker.checkHeaderColumnNames(tokenizer) - convertStream(inputStream, tokenizer, handleHeader) { tokens => + convertStream(inputStream, tokenizer, handleHeader, parser.options.charset) { tokens => safeParser.parse(tokens) }.flatten } @@ -305,9 +306,10 @@ private[sql] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, tokenizer: CsvParser, - handleHeader: () => Unit)( + handleHeader: () => Unit, + encoding: String)( convert: Array[String] => T) = new Iterator[T] { - tokenizer.beginParsing(inputStream) + tokenizer.beginParsing(inputStream, encoding) // We can handle header here since here the stream is open. 
handleHeader() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4808e8ef042d1..554baaf1a9b3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -192,7 +192,8 @@ object MultiLineCSVDataSource extends CSVDataSource { UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), shouldDropHeader = false, - new CsvParser(parsedOptions.asParserSettings)) + new CsvParser(parsedOptions.asParserSettings), + encoding = parsedOptions.charset) }.take(1).headOption match { case Some(firstRow) => val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -203,7 +204,8 @@ object MultiLineCSVDataSource extends CSVDataSource { lines.getConfiguration, new Path(lines.getPath())), parsedOptions.headerFlag, - new CsvParser(parsedOptions.asParserSettings)) + new CsvParser(parsedOptions.asParserSettings), + encoding = parsedOptions.charset) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) CSVInferSchema.infer(sampled, header, parsedOptions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2efe1dda475c5..e29cd2aa7c4e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1859,4 +1859,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(df, Row(null, csv)) } } + + test("encoding in multiLine mode") { + val df = spark.range(3).toDF() + Seq("UTF-8", "ISO-8859-1", "CP1251", "US-ASCII", "UTF-16BE", "UTF-32LE").foreach { encoding => + Seq(true, false).foreach { header => + withTempPath { path => + df.write + .option("encoding", encoding) + .option("header", header) + .csv(path.getCanonicalPath) + val readback = spark.read + .option("multiLine", true) + .option("encoding", encoding) + .option("inferSchema", true) + .option("header", header) + .csv(path.getCanonicalPath) + checkAnswer(readback, df) + } + } + } + } } From 4b7f7ef5007c2c8a5090f22c6e08927e9f9a407b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Nov 2018 09:31:12 +0800 Subject: [PATCH 087/145] [SPARK-26120][TESTS][SS][SPARKR] Fix a streaming query leak in Structured Streaming R tests ## What changes were proposed in this pull request? Stop the streaming query in `Specify a schema by using a DDL-formatted string when reading` to avoid outputting annoying logs. ## How was this patch tested? Jenkins Closes #23089 from zsxwing/SPARK-26120. 
Authored-by: Shixiong Zhu Signed-off-by: hyukjinkwon --- R/pkg/tests/fulltests/test_streaming.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index bfb1a046490ec..6f0d2aefee886 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -127,6 +127,7 @@ test_that("Specify a schema by using a DDL-formatted string when reading", { expect_false(awaitTermination(q, 5 * 1000)) callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + stopQuery(q) expect_error(read.stream(path = parquetPath, schema = "name stri"), "DataType stri is not supported.") From a480a6256318b43b963fb7414ccb789e4b950c8b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 21 Nov 2018 00:24:34 -0800 Subject: [PATCH 088/145] [SPARK-25954][SS] Upgrade to Kafka 2.1.0 ## What changes were proposed in this pull request? [Kafka 2.1.0 vote](https://lists.apache.org/thread.html/9f487094491e512b556a1c9c3c6034ac642b088e3f797e3d192ebc9d%3Cdev.kafka.apache.org%3E) passed. Since Kafka 2.1.0 includes official JDK 11 support [KAFKA-7264](https://issues.apache.org/jira/browse/KAFKA-7264), we had better use that. ## How was this patch tested? Pass the Jenkins. Closes #23099 from dongjoon-hyun/SPARK-25954. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 3f1055a75076f..d97e8cf18605e 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -30,7 +30,7 @@ sql-kafka-0-10 - 2.0.0 + 2.1.0 jar Kafka 0.10+ Source for Structured Streaming diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index d75b13da8fb70..cfc45559d8e34 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -29,7 +29,7 @@ streaming-kafka-0-10 - 2.0.0 + 2.1.0 jar Spark Integration for Kafka 0.10 From 540afc2b18ef61cceb50b9a5b327e6fcdbe1e7e4 Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 21 Nov 2018 09:31:35 -0600 Subject: [PATCH 089/145] [SPARK-26109][WEBUI] Duration in the task summary metrics table and the task table are different ## What changes were proposed in this pull request? Task summary table displays the summary of the task table in the stage page. However, the 'Duration' metrics of 'task summary' table and 'task table' are not matching. The reason is because, in the 'task summary' we display 'executorRunTime' as the duration, and in the 'task table' the actual duration of the task. Except duration metrics, all other metrics are properly displaying in the task summary. In Spark2.2, used to show 'executorRunTime' as duration in the 'taskTable'. That is why, in summary metrics also the 'exeuctorRunTime' shows as the duration. So, we need to show 'executorRunTime' as the duration in the tasks table to follow the same behaviour as the previous versions of spark. ## How was this patch tested? Before patch: ![screenshot from 2018-11-19 04-32-06](https://user-images.githubusercontent.com/23054875/48679263-1e4fff80-ebb4-11e8-9ed5-16d892039e01.png) After patch: ![screenshot from 2018-11-19 04-37-39](https://user-images.githubusercontent.com/23054875/48679343-e39a9700-ebb4-11e8-8df9-9dc3a28d4bce.png) Closes #23081 from shahidki31/duratinSummary. 
Authored-by: Shahid Signed-off-by: Sean Owen --- .../src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 477b9ce7f7848..7e6cc4297d6b1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -843,7 +843,7 @@ private[ui] class TaskPagedTable( {UIUtils.formatDate(task.launchTime)} - {formatDuration(task.duration)} + {formatDuration(task.taskMetrics.map(_.executorRunTime))} {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} @@ -996,7 +996,9 @@ private[ui] object ApiHelper { HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, HEADER_HOST -> TaskIndexNames.HOST, HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, - HEADER_DURATION -> TaskIndexNames.DURATION, + // SPARK-26109: Duration of task as executorRunTime to make it consistent with the + // aggregated tasks summary metrics table and the previous versions of Spark. + HEADER_DURATION -> TaskIndexNames.EXEC_RUN_TIME, HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, HEADER_GC_TIME -> TaskIndexNames.GC_TIME, From 6bbdf34baed7b2bab1fbfbce7782b3093a72812f Mon Sep 17 00:00:00 2001 From: Drew Robb Date: Wed, 21 Nov 2018 09:38:06 -0600 Subject: [PATCH 090/145] [SPARK-8288][SQL] ScalaReflection can use companion object constructor ## What changes were proposed in this pull request? This change fixes a particular scenario where default spark SQL can't encode (thrift) types that are generated by twitter scrooge. These types are a trait that extends `scala.ProductX` with a constructor defined only in a companion object, rather than a actual case class. The actual case class used is child class, but that type is almost never referred to in code. The type has no corresponding constructor symbol and causes an exception. For all other purposes, these classes act just like case classes, so it is unfortunate that spark SQL can't serialize them nicely as it can actual case classes. For an full example of a scrooge codegen class, see https://gist.github.com/anonymous/ba13d4b612396ca72725eaa989900314. This change catches the case where the type has no constructor but does have an `apply` method on the type's companion object. This allows for thrift types to be serialized/deserialized with implicit encoders the same way as normal case classes. This fix had to be done in three places where the constructor is assumed to be an actual constructor: 1) In serializing, determining the schema for the dataframe relies on inspecting its constructor (`ScalaReflection.constructParams`). Here we fall back to using the companion constructor arguments. 2) In deserializing or evaluating, in the java codegen ( `NewInstance.doGenCode`), the type couldn't be constructed with the new keyword. If there is no constructor, we change the constructor call to try the companion constructor. 3) In deserializing or evaluating, without codegen, the constructor is directly invoked (`NewInstance.constructor`). This was fixed with scala reflection to get the actual companion apply method. The return type of `findConstructor` was changed because the companion apply method constructor can't be represented as a `java.lang.reflect.Constructor`. 
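For illustration, a minimal usage sketch of the kind of type this enables (it mirrors the `ScroogeLikeExample` trait added in the tests below; `spark` is assumed to be an active `SparkSession`):
```scala
import spark.implicits._

// A scrooge-style trait whose only constructor is the companion's apply method;
// with this change the implicit encoder resolves it like a regular case class.
val ds = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2)).toDS()
ds.select("x").show()   // columns are derived from the Product fields, here: x
```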
There might be situations in which this approach would also fail in a new way, but it does at a minimum work for the specific scrooge example and will not impact cases that were already succeeding prior to this change Note: this fix does not enable using scrooge thrift enums, additional work for this is necessary. With this patch, it seems like you could patch `com.twitter.scrooge.ThriftEnum` to extend `_root_.scala.Product1[Int]` with `def _1 = value` to get spark's implicit encoders to handle enums, but I've yet to use this method myself. Note: I previously opened a PR for this issue, but only was able to fix case 1) there: https://github.com/apache/spark/pull/18766 ## How was this patch tested? I've fixed all 3 cases and added two tests that use a case class that is similar to scrooge generated one. The test in ScalaReflectionSuite checks 1), and the additional asserting in ObjectExpressionsSuite checks 2) and 3). Closes #23062 from drewrobb/SPARK-8288. Authored-by: Drew Robb Signed-off-by: Sean Owen --- .../spark/sql/catalyst/ScalaReflection.scala | 48 ++++++++++++++++--- .../expressions/objects/objects.scala | 18 ++++--- .../sql/catalyst/ScalaReflectionSuite.scala | 31 ++++++++++++ .../expressions/ObjectExpressionsSuite.scala | 10 ++++ .../org/apache/spark/sql/DatasetSuite.scala | 8 ++++ 5 files changed, 103 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 64ea236532839..c8542d0f2f7de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -788,12 +788,37 @@ object ScalaReflection extends ScalaReflection { } /** - * Finds an accessible constructor with compatible parameters. This is a more flexible search - * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible - * matching constructor is returned. Otherwise, it returns `None`. + * Finds an accessible constructor with compatible parameters. This is a more flexible search than + * the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned if it exists. Otherwise, we check for additional compatible + * constructors defined in the companion object as `apply` methods. Otherwise, it returns `None`. 
*/ - def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { - Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match { + case Some(c) => Some(x => c.newInstance(x: _*).asInstanceOf[T]) + case None => + val companion = mirror.staticClass(cls.getName).companion + val moduleMirror = mirror.reflectModule(companion.asModule) + val applyMethods = companion.asTerm.typeSignature + .member(universe.TermName("apply")).asTerm.alternatives + applyMethods.find { method => + val params = method.typeSignature.paramLists.head + // Check that the needed params are the same length and of matching types + params.size == paramTypes.tail.size && + params.zip(paramTypes.tail).forall { case(ps, pc) => + ps.typeSignature.typeSymbol == mirror.classSymbol(pc) + } + }.map { applyMethodSymbol => + val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size + val instanceMirror = mirror.reflect(moduleMirror.instance) + val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod) + (_args: Seq[AnyRef]) => { + // Drop the "outer" argument if it is provided + val args = if (_args.size == expectedArgsCount) _args else _args.tail + method.apply(args: _*).asInstanceOf[T] + } + } + } } /** @@ -973,8 +998,19 @@ trait ScalaReflection extends Logging { } } + /** + * If our type is a Scala trait it may have a companion object that + * only defines a constructor via `apply` method. + */ + private def getCompanionConstructor(tpe: Type): Symbol = { + tpe.typeSymbol.asClass.companion.asTerm.typeSignature.member(universe.TermName("apply")) + } + protected def constructParams(tpe: Type): Seq[Symbol] = { - val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR) + val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match { + case NoSymbol => getCompanionConstructor(tpe) + case sym => sym + } val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramLists } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4fd36a47cef52..59c897b6a53ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -462,12 +462,12 @@ case class NewInstance( val d = outerObj.getClass +: paramTypes val c = getConstructor(outerObj.getClass +: paramTypes) (args: Seq[AnyRef]) => { - c.newInstance(outerObj +: args: _*) + c(outerObj +: args) } }.getOrElse { val c = getConstructor(paramTypes) (args: Seq[AnyRef]) => { - c.newInstance(args: _*) + c(args) } } } @@ -486,10 +486,16 @@ case class NewInstance( ev.isNull = resultIsNull - val constructorCall = outer.map { gen => - s"${gen.value}.new ${cls.getSimpleName}($argString)" - }.getOrElse { - s"new $className($argString)" + val constructorCall = cls.getConstructors.size match { + // If there are no constructors, the `new` method will fail. In + // this case we can try to call the apply method constructor + // that might be defined on the companion object. 
+ case 0 => s"$className$$.MODULE$$.apply($argString)" + case _ => outer.map { gen => + s"${gen.value}.new ${cls.getSimpleName}($argString)" + }.getOrElse { + s"new $className($argString)" + } } val code = code""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index d98589db323cc..80824cc2a7f21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -109,6 +109,30 @@ object TestingUDT { } } +/** An example derived from Twitter/Scrooge codegen for thrift */ +object ScroogeLikeExample { + def apply(x: Int): ScroogeLikeExample = new Immutable(x) + + def unapply(_item: ScroogeLikeExample): Option[Int] = Some(_item.x) + + class Immutable(val x: Int) extends ScroogeLikeExample +} + +trait ScroogeLikeExample extends Product1[Int] with Serializable { + import ScroogeLikeExample._ + + def x: Int + + def _1: Int = x + + override def canEqual(other: Any): Boolean = other.isInstanceOf[ScroogeLikeExample] + + override def equals(other: Any): Boolean = + canEqual(other) && + this.x == other.asInstanceOf[ScroogeLikeExample].x + + override def hashCode: Int = x +} class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ @@ -362,4 +386,11 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } + + test("SPARK-8288: schemaFor works for a class with only a companion object constructor") { + val schema = schemaFor[ScroogeLikeExample] + assert(schema === Schema( + StructType(Seq( + StructField("x", IntegerType, nullable = false))), nullable = true)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 16842c1bcc8cb..436675bf50353 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -410,6 +411,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { dataType = ObjectType(classOf[outerObj.Inner]), outerPointer = Some(() => outerObj)) checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + + // SPARK-8288: A class with only a companion object constructor + val newInst3 = NewInstance( + cls = classOf[ScroogeLikeExample], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[ScroogeLikeExample]), + outerPointer = Some(() => outerObj)) + 
checkObjectExprEvaluation(newInst3, ScroogeLikeExample(1)) } test("LambdaVariable should support interpreted execution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ac677e8ec6bc2..540fbff6a3a63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide @@ -1570,6 +1571,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 1).as[Long]) checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L)) } + + test("SPARK-8288: class with only a companion object constructor") { + val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2)) + val ds = data.toDS + checkDataset(ds, data: _*) + checkAnswer(ds.select("x"), Seq(Row(1), Row(2))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 07a700b3711057553dfbb7b047216565726509c7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 21 Nov 2018 16:41:12 +0100 Subject: [PATCH 091/145] [SPARK-26129][SQL] Instrumentation for per-query planning time ## What changes were proposed in this pull request? We currently don't have good visibility into query planning time (analysis vs optimization vs physical planning). This patch adds a simple utility to track the runtime of various rules and various planning phases. ## How was this patch tested? Added unit tests and end-to-end integration tests. Closes #23096 from rxin/SPARK-26129. Authored-by: Reynold Xin Signed-off-by: Reynold Xin --- .../sql/catalyst/QueryPlanningTracker.scala | 127 ++++++++++++++++++ .../sql/catalyst/analysis/Analyzer.scala | 22 +-- .../sql/catalyst/rules/RuleExecutor.scala | 19 ++- .../catalyst/QueryPlanningTrackerSuite.scala | 78 +++++++++++ .../sql/catalyst/analysis/AnalysisTest.scala | 3 +- .../ResolveGroupingAnalyticsSuite.scala | 3 +- .../ResolvedUuidExpressionsSuite.scala | 10 +- .../scala/org/apache/spark/sql/Dataset.scala | 9 ++ .../org/apache/spark/sql/SparkSession.scala | 6 +- .../spark/sql/execution/QueryExecution.scala | 21 ++- .../QueryPlanningTrackerEndToEndSuite.scala | 52 +++++++ .../apache/spark/sql/hive/test/TestHive.scala | 16 ++- 12 files changed, 338 insertions(+), 28 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala new file mode 100644 index 0000000000000..420f2a1f20997 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.collection.JavaConverters._ + +import org.apache.spark.util.BoundedPriorityQueue + + +/** + * A simple utility for tracking runtime and associated stats in query planning. + * + * There are two separate concepts we track: + * + * 1. Phases: These are broad scope phases in query planning, as listed below, i.e. analysis, + * optimizationm and physical planning (just planning). + * + * 2. Rules: These are the individual Catalyst rules that we track. In addition to time, we also + * track the number of invocations and effective invocations. + */ +object QueryPlanningTracker { + + // Define a list of common phases here. + val PARSING = "parsing" + val ANALYSIS = "analysis" + val OPTIMIZATION = "optimization" + val PLANNING = "planning" + + class RuleSummary( + var totalTimeNs: Long, var numInvocations: Long, var numEffectiveInvocations: Long) { + + def this() = this(totalTimeNs = 0, numInvocations = 0, numEffectiveInvocations = 0) + + override def toString: String = { + s"RuleSummary($totalTimeNs, $numInvocations, $numEffectiveInvocations)" + } + } + + /** + * A thread local variable to implicitly pass the tracker around. This assumes the query planner + * is single-threaded, and avoids passing the same tracker context in every function call. + */ + private val localTracker = new ThreadLocal[QueryPlanningTracker]() { + override def initialValue: QueryPlanningTracker = null + } + + /** Returns the current tracker in scope, based on the thread local variable. */ + def get: Option[QueryPlanningTracker] = Option(localTracker.get()) + + /** Sets the current tracker for the execution of function f. We assume f is single-threaded. */ + def withTracker[T](tracker: QueryPlanningTracker)(f: => T): T = { + val originalTracker = localTracker.get() + localTracker.set(tracker) + try f finally { localTracker.set(originalTracker) } + } +} + + +class QueryPlanningTracker { + + import QueryPlanningTracker._ + + // Mapping from the name of a rule to a rule's summary. + // Use a Java HashMap for less overhead. + private val rulesMap = new java.util.HashMap[String, RuleSummary] + + // From a phase to time in ns. + private val phaseToTimeNs = new java.util.HashMap[String, Long] + + /** Measure the runtime of function f, and add it to the time for the specified phase. */ + def measureTime[T](phase: String)(f: => T): T = { + val startTime = System.nanoTime() + val ret = f + val timeTaken = System.nanoTime() - startTime + phaseToTimeNs.put(phase, phaseToTimeNs.getOrDefault(phase, 0) + timeTaken) + ret + } + + /** + * Record a specific invocation of a rule. 
+ * + * @param rule name of the rule + * @param timeNs time taken to run this invocation + * @param effective whether the invocation has resulted in a plan change + */ + def recordRuleInvocation(rule: String, timeNs: Long, effective: Boolean): Unit = { + var s = rulesMap.get(rule) + if (s eq null) { + s = new RuleSummary + rulesMap.put(rule, s) + } + + s.totalTimeNs += timeNs + s.numInvocations += 1 + s.numEffectiveInvocations += (if (effective) 1 else 0) + } + + // ------------ reporting functions below ------------ + + def rules: Map[String, RuleSummary] = rulesMap.asScala.toMap + + def phases: Map[String, Long] = phaseToTimeNs.asScala.toMap + + /** Returns the top k most expensive rules (as measured by time). */ + def topRulesByTime(k: Int): Seq[(String, RuleSummary)] = { + val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs) + val q = new BoundedPriorityQueue(k)(orderingByTime) + rulesMap.asScala.foreach(q.+=) + q.toSeq.sortBy(r => -r._2.totalTimeNs) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ab2312fdcdeef..b977fa07db5c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -102,16 +102,18 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } - def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer { - val analyzed = execute(plan) - try { - checkAnalysis(analyzed) - analyzed - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae + def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { + AnalysisHelper.markInAnalyzer { + val analyzed = executeAndTrack(plan, tracker) + try { + checkAnalysis(analyzed) + analyzed + } catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e991a2dc7462f..cf6ff4f986399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.rules import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -66,6 +67,17 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { */ protected def isPlanIntegral(plan: TreeType): Boolean = true + /** + * Executes the batches of rules defined by the subclass, and also tracks timing info for each + * rule using the provided tracker. + * @see [[execute]] + */ + def executeAndTrack(plan: TreeType, tracker: QueryPlanningTracker): TreeType = { + QueryPlanningTracker.withTracker(tracker) { + execute(plan) + } + } + /** * Executes the batches of rules defined by the subclass. 
The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. @@ -74,6 +86,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { var curPlan = plan val queryExecutionMetrics = RuleExecutor.queryExecutionMeter val planChangeLogger = new PlanChangeLogger() + val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get batches.foreach { batch => val batchStartPlan = curPlan @@ -88,8 +101,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime + val effective = !result.fastEquals(plan) - if (!result.fastEquals(plan)) { + if (effective) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) planChangeLogger.log(rule.ruleName, plan, result) @@ -97,6 +111,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) queryExecutionMetrics.incNumExecution(rule.ruleName) + // Record timing information using QueryPlanningTracker + tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective)) + // Run the structural integrity checker against the plan after each rule. if (!isPlanIntegral(result)) { val message = s"After applying rule ${rule.ruleName} in batch ${batch.name}, " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala new file mode 100644 index 0000000000000..f42c262dfbdd8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite + +class QueryPlanningTrackerSuite extends SparkFunSuite { + + test("phases") { + val t = new QueryPlanningTracker + t.measureTime("p1") { + Thread.sleep(1) + } + + assert(t.phases("p1") > 0) + assert(!t.phases.contains("p2")) + + val old = t.phases("p1") + + t.measureTime("p1") { + Thread.sleep(1) + } + assert(t.phases("p1") > old) + } + + test("rules") { + val t = new QueryPlanningTracker + t.recordRuleInvocation("r1", 1, effective = false) + t.recordRuleInvocation("r2", 2, effective = true) + t.recordRuleInvocation("r3", 1, effective = false) + t.recordRuleInvocation("r3", 2, effective = true) + + val rules = t.rules + + assert(rules("r1").totalTimeNs == 1) + assert(rules("r1").numInvocations == 1) + assert(rules("r1").numEffectiveInvocations == 0) + + assert(rules("r2").totalTimeNs == 2) + assert(rules("r2").numInvocations == 1) + assert(rules("r2").numEffectiveInvocations == 1) + + assert(rules("r3").totalTimeNs == 3) + assert(rules("r3").numInvocations == 2) + assert(rules("r3").numEffectiveInvocations == 1) + } + + test("topRulesByTime") { + val t = new QueryPlanningTracker + t.recordRuleInvocation("r2", 2, effective = true) + t.recordRuleInvocation("r4", 4, effective = true) + t.recordRuleInvocation("r1", 1, effective = false) + t.recordRuleInvocation("r3", 3, effective = false) + + val top = t.topRulesByTime(2) + assert(top.size == 2) + assert(top(0)._1 == "r4") + assert(top(1)._1 == "r3") + + // Don't crash when k > total size + assert(t.topRulesByTime(10).size == 4) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3d7c91870133b..fab1b776a3c72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -21,6 +21,7 @@ import java.net.URI import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -54,7 +55,7 @@ trait AnalysisTest extends PlanTest { expectedPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - val actualPlan = analyzer.executeAndCheck(inputPlan) + val actualPlan = analyzer.executeAndCheck(inputPlan, new QueryPlanningTracker) comparePlans(actualPlan, expectedPlan) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 8da4d7e3aa372..aa5eda8e5ba87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -109,7 +110,7 @@ class 
ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), unresolved_b, UnresolvedAlias(count(unresolved_c)))) - val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2) + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2, new QueryPlanningTracker) val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions assert(gExpressions.size == 3) val firstGroupingExprAttrName = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala index fe57c199b8744..64bd07534b19b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -34,6 +35,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { private lazy val uuid3 = Uuid().as('_uuid3) private lazy val uuid1Ref = uuid1.toAttribute + private val tracker = new QueryPlanningTracker private val analyzer = getAnalyzer(caseSensitive = true) private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { @@ -47,7 +49,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { test("analyzed plan sets random seed for Uuid expression") { val plan = r.select(a, uuid1) - val resolvedPlan = analyzer.executeAndCheck(plan) + val resolvedPlan = analyzer.executeAndCheck(plan, tracker) getUuidExpressions(resolvedPlan).foreach { u => assert(u.resolved) assert(u.randomSeed.isDefined) @@ -56,14 +58,14 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { test("Uuid expressions should have different random seeds") { val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) - val resolvedPlan = analyzer.executeAndCheck(plan) + val resolvedPlan = analyzer.executeAndCheck(plan, tracker) assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) } test("Different analyzed plans should have different random seeds in Uuids") { val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) - val resolvedPlan1 = analyzer.executeAndCheck(plan) - val resolvedPlan2 = analyzer.executeAndCheck(plan) + val resolvedPlan1 = analyzer.executeAndCheck(plan, tracker) + val resolvedPlan2 = analyzer.executeAndCheck(plan, tracker) val uuids1 = getUuidExpressions(resolvedPlan1) val uuids2 = getUuidExpressions(resolvedPlan2) assert(uuids1.distinct.length == 3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0e77ec0406257..e757921b485df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ 
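The `QueryPlanningTracker` imported in the hunk above is threaded through a new `ofRows`
overload in the next hunk and later surfaces as `Dataset.queryExecution.tracker`. A minimal
sketch of reading it back once a query has run (assumes an active `SparkSession` named
`spark`; phase times are recorded in nanoseconds):

    val df = spark.sql("select count(*) from range(1000)")
    df.collect()
    val tracker = df.queryExecution.tracker
    // per-phase planning time: parsing, analysis, optimization, planning
    tracker.phases.foreach { case (phase, ns) => println(s"$phase: ${ns / 1e6} ms") }
    // the three most expensive Catalyst rules, with invocation counts
    tracker.topRulesByTime(3).foreach { case (rule, summary) => println(s"$rule: $summary") }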
@@ -76,6 +77,14 @@ private[sql] object Dataset { qe.assertAnalyzed() new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) } + + /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */ + def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker) + : DataFrame = { + val qe = new QueryExecution(sparkSession, logicalPlan, tracker) + qe.assertAnalyzed() + new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 725db97df4ed1..739c6b54b4cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -648,7 +648,11 @@ class SparkSession private( * @since 2.0.0 */ def sql(sqlText: String): DataFrame = { - Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText)) + val tracker = new QueryPlanningTracker + val plan = tracker.measureTime(QueryPlanningTracker.PARSING) { + sessionState.sqlParser.parsePlan(sqlText) + } + Dataset.ofRows(self, plan, tracker) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 905d035b64275..87a4ceb91aae6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule @@ -43,7 +43,10 @@ import org.apache.spark.util.Utils * While this is not a public class, we should avoid changing the function names for the sake of * changing them, because a lot of developers use the feature for debugging. */ -class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { +class QueryExecution( + val sparkSession: SparkSession, + val logical: LogicalPlan, + val tracker: QueryPlanningTracker = new QueryPlanningTracker) { // TODO: Move the planner an optimizer into here from SessionState. 
protected def planner = sparkSession.sessionState.planner @@ -56,9 +59,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } } - lazy val analyzed: LogicalPlan = { + lazy val analyzed: LogicalPlan = tracker.measureTime(QueryPlanningTracker.ANALYSIS) { SparkSession.setActiveSession(sparkSession) - sparkSession.sessionState.analyzer.executeAndCheck(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } lazy val withCachedData: LogicalPlan = { @@ -67,9 +70,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { sparkSession.sharedState.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData) + lazy val optimizedPlan: LogicalPlan = tracker.measureTime(QueryPlanningTracker.OPTIMIZATION) { + sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker) + } - lazy val sparkPlan: SparkPlan = { + lazy val sparkPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. @@ -78,7 +83,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { + prepareForExecution(sparkPlan) + } /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala new file mode 100644 index 0000000000000..0af4c85400e9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.test.SharedSQLContext + +class QueryPlanningTrackerEndToEndSuite extends SharedSQLContext { + + test("programmatic API") { + val df = spark.range(1000).selectExpr("count(*)") + df.collect() + val tracker = df.queryExecution.tracker + + assert(tracker.phases.size == 3) + assert(tracker.phases("analysis") > 0) + assert(tracker.phases("optimization") > 0) + assert(tracker.phases("planning") > 0) + + assert(tracker.rules.nonEmpty) + } + + test("sql") { + val df = spark.sql("select * from range(1)") + df.collect() + + val tracker = df.queryExecution.tracker + + assert(tracker.phases.size == 4) + assert(tracker.phases("parsing") > 0) + assert(tracker.phases("analysis") > 0) + assert(tracker.phases("optimization") > 0) + assert(tracker.phases("planning") > 0) + + assert(tracker.rules.nonEmpty) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 634b3db19ec27..3508affda241a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -33,9 +33,9 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} @@ -219,6 +219,16 @@ private[hive] class TestHiveSparkSession( sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() } + /** + * This is a temporary hack to override SparkSession.sql so we can still use the version of + * Dataset.ofRows that creates a TestHiveQueryExecution (rather than a normal QueryExecution + * which wouldn't load all the test tables). + */ + override def sql(sqlText: String): DataFrame = { + val plan = sessionState.sqlParser.parsePlan(sqlText) + Dataset.ofRows(self, plan) + } + override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) } @@ -586,7 +596,7 @@ private[hive] class TestHiveQueryExecution( logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. - sparkSession.sessionState.analyzer.executeAndCheck(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } } From 81550b38e43fb20f89f529d2127575c71a54a538 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 21 Nov 2018 11:16:54 -0800 Subject: [PATCH 092/145] [SPARK-26066][SQL] Move truncatedString to sql/catalyst and add spark.sql.debug.maxToStringFields conf ## What changes were proposed in this pull request? In the PR, I propose: - new SQL config `spark.sql.debug.maxToStringFields` to control maximum number fields up to which `truncatedString` cuts its input sequences. 
- Moving `truncatedString` out of `core` to `sql/catalyst` because it is used only in the `sql/catalyst` packages for restricting number of fields converted to strings from `TreeNode` and expressions of`StructType`. ## How was this patch tested? Added a test to `QueryExecutionSuite` to check that `spark.sql.debug.maxToStringFields` impacts to behavior of `truncatedString`. Closes #23039 from MaxGekk/truncated-string-catalyst. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/util/Utils.scala | 48 ------------------- .../org/apache/spark/util/UtilsSuite.scala | 8 ---- .../sql/catalyst/expressions/Expression.scala | 4 +- .../plans/logical/basicLogicalOperators.scala | 4 +- .../spark/sql/catalyst/trees/TreeNode.scala | 10 ++-- .../spark/sql/catalyst/util/package.scala | 37 +++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../apache/spark/sql/types/StructType.scala | 4 +- .../org/apache/spark/sql/util/UtilSuite.scala | 31 ++++++++++++ .../sql/execution/DataSourceScanExec.scala | 5 +- .../spark/sql/execution/ExistingRDD.scala | 4 +- .../spark/sql/execution/QueryExecution.scala | 3 +- .../aggregate/HashAggregateExec.scala | 7 +-- .../aggregate/ObjectHashAggregateExec.scala | 8 ++-- .../aggregate/SortAggregateExec.scala | 8 ++-- .../execution/columnar/InMemoryRelation.scala | 5 +- .../datasources/LogicalRelation.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 5 +- .../v2/DataSourceV2StringFormat.scala | 5 +- .../apache/spark/sql/execution/limit.scala | 6 +-- .../streaming/MicroBatchExecution.scala | 7 +-- .../continuous/ContinuousExecution.scala | 7 +-- .../sql/execution/streaming/memory.scala | 4 +- .../sql/execution/QueryExecutionSuite.scala | 26 ++++++++++ 24 files changed, 156 insertions(+), 103 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 743fd5d75b2db..227c9e734f0af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -31,7 +31,6 @@ import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.TimeUnit.NANOSECONDS -import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream import scala.annotation.tailrec @@ -93,53 +92,6 @@ private[spark] object Utils extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null - /** - * The performance overhead of creating and logging strings for wide schemas can be large. To - * limit the impact, we bound the number of fields to include by default. This can be overridden - * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv. - */ - val DEFAULT_MAX_TO_STRING_FIELDS = 25 - - private[spark] def maxNumToStringFields = { - if (SparkEnv.get != null) { - SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) - } else { - DEFAULT_MAX_TO_STRING_FIELDS - } - } - - /** Whether we have warned about plan string truncation yet. */ - private val truncationWarningPrinted = new AtomicBoolean(false) - - /** - * Format a sequence with semantics similar to calling .mkString(). Any elements beyond - * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. 
- * - * @return the trimmed and formatted string. - */ - def truncatedString[T]( - seq: Seq[T], - start: String, - sep: String, - end: String, - maxNumFields: Int = maxNumToStringFields): String = { - if (seq.length > maxNumFields) { - if (truncationWarningPrinted.compareAndSet(false, true)) { - logWarning( - "Truncated the string representation of a plan since it was too large. This " + - "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.") - } - val numFields = math.max(0, maxNumFields - 1) - seq.take(numFields).mkString( - start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) - } else { - seq.mkString(start, sep, end) - } - } - - /** Shorthand for calling truncatedString() without start or end strings. */ - def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 5293645cab058..f5e912b50d1ab 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -45,14 +45,6 @@ import org.apache.spark.scheduler.SparkListener class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { - test("truncatedString") { - assert(Utils.truncatedString(Nil, "[", ", ", "]", 2) == "[]") - assert(Utils.truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") - assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") - assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]") - assert(Utils.truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") - } - test("timeConversion") { // Test -1 assert(Utils.timeStringAsSeconds("-1") === -1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 141fcffcb6fab..d51b11024a09d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the basic expression abstract classes in Catalyst. 
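For reference, the relocated helper keeps the same truncation behaviour, but its default cap
now comes from the new SQL config rather than a SparkEnv property. A small sketch (the
literal results mirror the `UtilSuite` test added later in this patch):

    import org.apache.spark.sql.catalyst.util.truncatedString

    truncatedString(Seq(1, 2, 3), ", ")               // "1, 2, 3"
    truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2)  // "[1, ... 2 more fields]"
    // With no explicit limit the cap is SQLConf.get.maxToStringFields, i.e. the new
    // spark.sql.debug.maxToStringFields setting (default 25).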
@@ -237,7 +237,7 @@ abstract class Expression extends TreeNode[Expression] { override def simpleString: String = toString - override def toString: String = prettyName + Utils.truncatedString( + override def toString: String = prettyName + truncatedString( flatArguments.toSeq, "(", ", ", ")") /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f09c5ceefed13..07fa17b233a47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler /** @@ -485,7 +485,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) override def output: Seq[Attribute] = child.output override def simpleString: String = { - val cteAliases = Utils.truncatedString(cteRelations.map(_._1), "[", ", ", "]") + val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]") s"CTE $cteAliases" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 1027216165005..2e9f9f53e94ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -440,10 +440,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case tn: TreeNode[_] => tn.simpleString :: Nil case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil case iter: Iterable[_] if iter.isEmpty => Nil - case seq: Seq[_] => Utils.truncatedString(seq, "[", ", ", "]") :: Nil - case set: Set[_] => Utils.truncatedString(set.toSeq, "{", ", ", "}") :: Nil + case seq: Seq[_] => truncatedString(seq, "[", ", ", "]") :: Nil + case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}") :: Nil case array: Array[_] if array.isEmpty => Nil - case array: Array[_] => Utils.truncatedString(array, "[", ", ", "]") :: Nil + case array: Array[_] => truncatedString(array, "[", ", ", "]") :: Nil case null => Nil case None => Nil case Some(null) => Nil @@ -664,7 +664,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { t.forall(_.isInstanceOf[Partitioning]) || 
t.forall(_.isInstanceOf[DataType]) => JArray(t.map(parseToJson).toList) case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] => - JString(Utils.truncatedString(t, "[", ", ", "]")) + JString(truncatedString(t, "[", ", ", "]")) case t: Seq[_] => JNull case m: Map[_, _] => JNull // if it's a scala object, we can simply keep the full class path. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 0978e92dd4f72..277584b20dcd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,13 +19,16 @@ package org.apache.spark.sql.catalyst import java.io._ import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -package object util { +package object util extends Logging { /** Silences output to stderr or stdout for the duration of f */ def quietly[A](f: => A): A = { @@ -167,6 +170,38 @@ package object util { builder.toString() } + /** Whether we have warned about plan string truncation yet. */ + private val truncationWarningPrinted = new AtomicBoolean(false) + + /** + * Format a sequence with semantics similar to calling .mkString(). Any elements beyond + * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. + * + * @return the trimmed and formatted string. + */ + def truncatedString[T]( + seq: Seq[T], + start: String, + sep: String, + end: String, + maxNumFields: Int = SQLConf.get.maxToStringFields): String = { + if (seq.length > maxNumFields) { + if (truncationWarningPrinted.compareAndSet(false, true)) { + logWarning( + "Truncated the string representation of a plan since it was too large. This " + + s"behavior can be adjusted by setting '${SQLConf.MAX_TO_STRING_FIELDS.key}'.") + } + val numFields = math.max(0, maxNumFields - 1) + seq.take(numFields).mkString( + start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) + } else { + seq.mkString(start, sep, end) + } + } + + /** Shorthand for calling truncatedString() without start or end strings. */ + def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 518115dafd011..cc0e9727812db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1594,6 +1594,13 @@ object SQLConf { "WHERE, which does not follow SQL standard.") .booleanConf .createWithDefault(false) + + val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields") + .doc("Maximum number of fields of sequence-like entries can be converted to strings " + + "in debug output. Any elements beyond the limit will be dropped and replaced by a" + + """ "... 
N more fields" placeholder.""") + .intConf + .createWithDefault(25) } /** @@ -2009,6 +2016,8 @@ class SQLConf extends Serializable with Logging { def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG) + def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3bef75d5bdb6e..6e8bbde7787a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString} import org.apache.spark.util.Utils /** @@ -346,7 +346,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def simpleString: String = { val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}") - Utils.truncatedString(fieldTypes, "struct<", ",", ">") + truncatedString(fieldTypes, "struct<", ",", ">") } override def catalogString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala new file mode 100644 index 0000000000000..9c162026942f6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.truncatedString + +class UtilSuite extends SparkFunSuite { + test("truncatedString") { + assert(truncatedString(Nil, "[", ", ", "]", 2) == "[]") + assert(truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") + assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") + assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 
3 more fields]") + assert(truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a9b18ab57237d..77e381ef6e6b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -56,8 +57,8 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) } - val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") - s"$nodeNamePrefix$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" + val metadataStr = truncatedString(metadataEntries, " ", ", ", "") + s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]")}$metadataStr" } override def verboseString: String = redact(super.verboseString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2962becb64e88..9f67d556af362 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -197,6 +197,6 @@ case class RDDScanExec( } override def simpleString: String = { - s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" + s"$nodeName${truncatedString(output, "[", ",", "]")}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 87a4ceb91aae6..cfb5e43207b03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, 
ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} @@ -213,7 +214,7 @@ class QueryExecution( writer.write("== Parsed Logical Plan ==\n") writeOrError(writer)(logical.treeString(_, verbose, addSuffix)) writer.write("\n== Analyzed Logical Plan ==\n") - val analyzedOutput = stringOrError(Utils.truncatedString( + val analyzedOutput = stringOrError(truncatedString( analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")) writer.write(analyzedOutput) writer.write("\n") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 08dcdf33fb8f2..4827f838fc514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow @@ -930,9 +931,9 @@ case class HashAggregateExec( testFallbackStartsAt match { case None => - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 66955b8ef723c..7145bb03028d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.Utils /** * A hash-based aggregate operator that supports [[TypedImperativeAggregate]] functions that may @@ -143,9 +143,9 @@ case class ObjectHashAggregateExec( private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val 
keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index fc87de2c52e41..d732b905dcdd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.Utils /** * Sort-based aggregate operator. @@ -114,9 +114,9 @@ case class SortAggregateExec( private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 3b6588587c35a..73eb65f84489c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,9 +27,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{LongAccumulator, Utils} +import org.apache.spark.util.LongAccumulator /** @@ -209,5 +210,5 @@ case class InMemoryRelation( override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) override def simpleString: String = - s"InMemoryRelation [${Utils.truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" + s"InMemoryRelation [${truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 8d715f6342988..1023572d19e2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.util.Utils /** * Used to link a [[BaseRelation]] in to a logical query plan. @@ -63,7 +63,7 @@ case class LogicalRelation( case _ => // Do nothing. } - override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation" + override def simpleString: String = s"Relation[${truncatedString(output, ",")}] $relation" } object LogicalRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index f15014442e3fb..51c385e25bee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,10 +27,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType} -import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. 
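The `LogicalRelation.simpleString` change above routes through the same helper, so the
one-line description of a very wide scan now degrades gracefully in `explain()` output.
Rough shape of the result (attribute names, ids and the column count are made up):

    // with spark.sql.debug.maxToStringFields=3, a 100-column parquet relation prints as
    // Relation[c0#0,c1#1,... 98 more fields] parquet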
@@ -159,8 +159,9 @@ private[sql] object JDBCRelation extends Logging { val column = schema.find { f => resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName) }.getOrElse { + val maxNumToStringFields = SQLConf.get.maxToStringFields throw new AnalysisException(s"User-defined partition column $columnName not " + - s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + s"found in the JDBC relation: ${schema.simpleString(maxNumToStringFields)}") } column.dataType match { case _: NumericType | DateType | TimestampType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index 97e6c6d702acb..e829f621b4ea3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.util.Utils @@ -72,10 +73,10 @@ trait DataSourceV2StringFormat { }.mkString("[", ",", "]") } - val outputStr = Utils.truncatedString(output, "[", ", ", "]") + val outputStr = truncatedString(output, "[", ", ", "]") val entriesStr = if (entries.nonEmpty) { - Utils.truncatedString(entries.map { + truncatedString(entries.map { case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) }, " (", ", ", ")") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 9bfe1a79fc1e1..90dafcf535914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.util.Utils /** * Take the first `limit` elements and collect them to a single partition. 
@@ -177,8 +177,8 @@ case class TakeOrderedAndProjectExec( override def outputPartitioning: Partitioning = SinglePartition override def simpleString: String = { - val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]") - val outputString = Utils.truncatedString(output, "[", ",", "]") + val orderByString = truncatedString(sortOrder, "[", ",", "]") + val outputString = truncatedString(output, "[", ",", "]") s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2cac86599ef19..5defca391a355 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,13 +24,14 @@ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.util.Clock class MicroBatchExecution( sparkSession: SparkSession, @@ -475,8 +476,8 @@ class MicroBatchExecution( case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => assert(output.size == dataPlan.output.size, - s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(dataPlan.output, ",")}") + s"Invalid batch: ${truncatedString(output, ",")} != " + + s"${truncatedString(dataPlan.output, ",")}") val aliases = output.zip(dataPlan.output).map { case (to, from) => Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 4a7df731da67d..1eab55122e84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, 
StreamingRelationV2, _} @@ -35,7 +36,7 @@ import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.util.Clock class ContinuousExecution( sparkSession: SparkSession, @@ -164,8 +165,8 @@ class ContinuousExecution( val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, - s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newOutput, ",")}") + s"Invalid reader: ${truncatedString(output, ",")} != " + + s"${truncatedString(newOutput, ",")}") replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index adf52aba21a04..daee089f3871d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils object MemoryStream { protected val currentBlockId = new AtomicInteger(0) @@ -117,7 +117,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" + override def toString: String = s"MemoryStream[${truncatedString(output, ",")}]" override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index a5922d7c825db..0c47a2040f171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -20,9 +20,20 @@ import scala.io.Source import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +case class QueryExecutionTestRecord( + c0: Int, c1: Int, c2: Int, c3: Int, c4: Int, + c5: Int, c6: Int, c7: Int, c8: Int, c9: Int, + c10: Int, c11: Int, c12: Int, c13: Int, c14: Int, + c15: Int, c16: Int, c17: Int, c18: Int, c19: Int, + c20: Int, c21: Int, c22: Int, c23: Int, c24: Int, + c25: Int, c26: Int) + class QueryExecutionSuite extends SharedSQLContext { + import testImplicits._ + def checkDumpedPlans(path: String, expected: Int): Unit = { 
assert(Source.fromFile(path).getLines.toList .takeWhile(_ != "== Whole Stage Codegen ==") == List( @@ -80,6 +91,21 @@ class QueryExecutionSuite extends SharedSQLContext { assert(exception.getMessage.contains("Illegal character in scheme name")) } + test("limit number of fields by sql config") { + def relationPlans: String = { + val ds = spark.createDataset(Seq(QueryExecutionTestRecord( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26))) + ds.queryExecution.toString + } + withSQLConf(SQLConf.MAX_TO_STRING_FIELDS.key -> "26") { + assert(relationPlans.contains("more fields")) + } + withSQLConf(SQLConf.MAX_TO_STRING_FIELDS.key -> "27") { + assert(!relationPlans.contains("more fields")) + } + } + test("toString() exception/error handling") { spark.experimental.extraStrategies = Seq( new SparkStrategy { From 4aa9ccbde7870fb2750712e9e38e6aad740e0770 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Nov 2018 17:03:57 -0600 Subject: [PATCH 093/145] [SPARK-26127][ML] Remove deprecated setters from tree regression and classification models ## What changes were proposed in this pull request? The setter methods are deprecated since 2.1 for the models of regression and classification using trees. The deprecation was stating that the method would have been removed in 3.0. Hence the PR removes the deprecated method. ## How was this patch tested? NA Closes #23093 from mgaido91/SPARK-26127. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../DecisionTreeClassifier.scala | 18 +-- .../ml/classification/GBTClassifier.scala | 26 ++--- .../RandomForestClassifier.scala | 24 ++-- .../ml/regression/DecisionTreeRegressor.scala | 18 +-- .../spark/ml/regression/GBTRegressor.scala | 27 +++-- .../ml/regression/RandomForestRegressor.scala | 24 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 105 ------------------ project/MimaExcludes.scala | 74 +++++++++++- 8 files changed, 138 insertions(+), 178 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 6648e78d8eafa..bcf89766b0873 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -55,27 +55,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + 
def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -87,15 +87,15 @@ class DecisionTreeClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) override protected def train( dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62c6bdbdeb285..fab8155add5a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -69,27 +69,27 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -101,7 +101,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. 
@@ -110,7 +110,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this } @@ -119,25 +119,25 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) /** @group setParam */ @Since("2.3.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) // Parameters from GBTClassifierParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 57132381b6474..05fff8885fbf2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -57,27 +57,27 @@ class RandomForestClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -89,31 +89,31 @@ class RandomForestClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train( diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index c9de85de42fa5..faadc4d7b4ccc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -54,27 +54,27 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S // Override parameter setters from parent trait for Java API compatibility. /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -86,15 +86,15 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 07f88d8d5f84d..186fa2399af05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -69,27 +68,27 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -101,7 +100,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. 
@@ -110,7 +109,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } @@ -119,21 +118,21 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTRegressorParams: @@ -143,7 +142,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("2.3.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 66d57ad6c4348..7f5e668ca71db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -56,27 +56,27 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -88,31 +88,31 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train( diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index f1e3836ebe476..c06c68d44ae1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -110,80 +110,24 @@ private[ml] trait DecisionTreeParams extends PredictorParams setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) - /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) - /** @group getParam */ final def getMaxBins: Int = $(maxBins) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) - /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setSeed(value: Long): this.type = set(seed, value) - - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. 
- * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) - /** @group expertGetParam */ final def getMaxMemoryInMB: Int = $(maxMemoryInMB) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) - /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], @@ -226,13 +170,6 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -273,13 +210,6 @@ private[ml] trait HasVarianceImpurity extends Params { setDefault(impurity -> "variance") - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -346,13 +276,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(subsamplingRate -> 1.0) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) - /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) @@ -406,13 +329,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(featureSubsetStrategy -> "auto") - /** - * @deprecated This method is deprecated and will be removed in 3.0.0 - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - /** @group getParam */ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } @@ -440,13 +356,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(numTrees -> 20) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. 
- * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) - /** @group getParam */ final def getNumTrees: Int = $(numTrees) } @@ -491,13 +400,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @Since("2.4.0") final def getValidationTol: Double = $(validationTol) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) - /** * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking * the contribution of each estimator. @@ -508,13 +410,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.", ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b750535e8a70b..9089c7d9ffc70 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,76 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-26127] Remove deprecated setters from tree regression and classification models + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setNumTrees"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setNumTrees"), + // [SPARK-26124] Update plugins, including MiMa ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns.build"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.fullSchema"), @@ -50,15 +120,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), From 9b48107f9c84631e0ddaf0f2223296a3cbc16f83 Mon Sep 17 00:00:00 2001 From: Nagaram Prasad Addepally Date: Wed, 21 Nov 2018 15:51:37 -0800 Subject: [PATCH 094/145] [SPARK-25957][K8S] Make building alternate language binding docker images optional ## What changes were proposed in this pull request? bin/docker-image-tool.sh tries to build all docker images (JVM, PySpark and SparkR) by default. But not all spark distributions are built with SparkR and hence this script will fail on such distros. With this change, we make building alternate language binding docker images (PySpark and SparkR) optional. User has to specify dockerfile for those language bindings using -p and -R flags accordingly, to build the binding docker images. ## How was this patch tested? Tested following scenarios. 
*bin/docker-image-tool.sh -r -t build* --> Builds only JVM docker image (default behavior) *bin/docker-image-tool.sh -r -t -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build* --> Builds both JVM and PySpark docker images *bin/docker-image-tool.sh -r -t -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile -R kubernetes/dockerfiles/spark/bindings/R/Dockerfile build* --> Builds JVM, PySpark and SparkR docker images. Author: Nagaram Prasad Addepally Closes #23053 from ramaddepally/SPARK-25957. --- bin/docker-image-tool.sh | 63 +++++++++++-------- docs/running-on-kubernetes.md | 12 ++++ .../scripts/setup-integration-test-env.sh | 12 +++- 3 files changed, 59 insertions(+), 28 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index aa5d847f4be2f..e51201a77cb5d 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -41,6 +41,18 @@ function image_ref { echo "$image" } +function docker_push { + local image_name="$1" + if [ ! -z $(docker images -q "$(image_ref ${image_name})") ]; then + docker push "$(image_ref ${image_name})" + if [ $? -ne 0 ]; then + error "Failed to push $image_name Docker image." + fi + else + echo "$(image_ref ${image_name}) image not found. Skipping push for this image." + fi +} + function build { local BUILD_ARGS local IMG_PATH @@ -92,8 +104,8 @@ function build { base_img=$(image_ref spark) ) local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} - local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} - local RDOCKERFILE=${RDOCKERFILE:-"$IMG_PATH/spark/bindings/R/Dockerfile"} + local PYDOCKERFILE=${PYDOCKERFILE:-false} + local RDOCKERFILE=${RDOCKERFILE:-false} docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ @@ -102,33 +114,29 @@ function build { error "Failed to build Spark JVM Docker image, please refer to Docker build output for details." fi - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ - -t $(image_ref spark-py) \ - -f "$PYDOCKERFILE" . + if [ "${PYDOCKERFILE}" != "false" ]; then + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" . + if [ $? -ne 0 ]; then + error "Failed to build PySpark Docker image, please refer to Docker build output for details." + fi + fi + + if [ "${RDOCKERFILE}" != "false" ]; then + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-r) \ + -f "$RDOCKERFILE" . if [ $? -ne 0 ]; then - error "Failed to build PySpark Docker image, please refer to Docker build output for details." + error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ - -t $(image_ref spark-r) \ - -f "$RDOCKERFILE" . - if [ $? -ne 0 ]; then - error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi } function push { - docker push "$(image_ref spark)" - if [ $? -ne 0 ]; then - error "Failed to push Spark JVM Docker image." - fi - docker push "$(image_ref spark-py)" - if [ $? -ne 0 ]; then - error "Failed to push PySpark Docker image." - fi - docker push "$(image_ref spark-r)" - if [ $? -ne 0 ]; then - error "Failed to push SparkR Docker image." - fi + docker_push "spark" + docker_push "spark-py" + docker_push "spark-r" } function usage { @@ -143,8 +151,10 @@ Commands: Options: -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. - -p file Dockerfile to build for PySpark Jobs. 
Builds Python dependencies and ships with Spark. - -R file Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + -p file (Optional) Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. + Skips building PySpark docker image if not specified. + -R file (Optional) Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + Skips building SparkR docker image if not specified. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -164,6 +174,9 @@ Examples: - Build image in minikube with tag "testing" $0 -m -t testing build + - Build PySpark docker image + $0 -r docker.io/myrepo -t v2.3.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build + - Build and push image with tag "v2.3.0" to docker.io/myrepo $0 -r docker.io/myrepo -t v2.3.0 build $0 -r docker.io/myrepo -t v2.3.0 push diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index a7b6fd12a3e5f..a9d448820e700 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -88,6 +88,18 @@ $ ./bin/docker-image-tool.sh -r -t my-tag build $ ./bin/docker-image-tool.sh -r -t my-tag push ``` +By default `bin/docker-image-tool.sh` builds docker image for running JVM jobs. You need to opt-in to build additional +language binding docker images. + +Example usage is +```bash +# To build additional PySpark docker image +$ ./bin/docker-image-tool.sh -r -t my-tag -p ./kubernetes/dockerfiles/spark/bindings/python/Dockerfile build + +# To build additional SparkR docker image +$ ./bin/docker-image-tool.sh -r -t my-tag -R ./kubernetes/dockerfiles/spark/bindings/R/Dockerfile build +``` + ## Cluster Mode To launch Spark Pi in cluster mode, diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index a4a9f5b7da131..36e30d7b2cffb 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -72,10 +72,16 @@ then IMAGE_TAG=$(uuidgen); cd $UNPACKED_SPARK_TGZ + # Build PySpark image + LANGUAGE_BINDING_BUILD_ARGS="-p $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/python/Dockerfile" + + # Build SparkR image + LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/R/Dockerfile" + case $DEPLOY_MODE in cloud) # Build images - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build # Push images appropriately if [[ $IMAGE_REPO == gcr.io* ]] ; @@ -89,13 +95,13 @@ then docker-for-desktop) # Only need to build as this will place it in our local Docker repo which is all # we need for Docker for Desktop to work so no need to also push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; minikube) # Only need to build and if we do this with the -m option for minikube we will # build the images directly using the minikube Docker daemon so no need to push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + 
$UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; *) echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1 From ce7b57cb5d552ac3df8557a3863792c425005994 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Nov 2018 08:02:23 +0800 Subject: [PATCH 095/145] [SPARK-26106][PYTHON] Prioritizes ML unittests over the doctests in PySpark ## What changes were proposed in this pull request? Arguably, unittests usually takes longer then doctests. We better prioritize unittests over doctests. Other modules are already being prioritized over doctests. Looks ML module was missed at the very first place. ## How was this patch tested? Jenkins tests. Closes #23078 from HyukjinKwon/SPARK-26106. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/run-tests.py b/python/run-tests.py index 9fd1c9b94ac6f..01a6e81264dd6 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -250,7 +250,7 @@ def main(): if python_implementation not in module.blacklisted_python_implementations: for test_goal in module.python_test_goals: heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', - 'pyspark.tests', 'pyspark.sql.tests'] + 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): priority = 0 else: From 38628dd1b8298d2686e5d00de17c461c70db99a8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 22 Nov 2018 09:35:29 +0800 Subject: [PATCH 096/145] [SPARK-25935][SQL] Prevent null rows from JSON parser ## What changes were proposed in this pull request? An input without valid JSON tokens on the root level will be treated as a bad record, and handled according to `mode`. Previously such input was converted to `null`. After the changes, the input is converted to a row with `null`s in the `PERMISSIVE` mode according the schema. This allows to remove a code in the `from_json` function which can produce `null` as result rows. ## How was this patch tested? It was tested by existing test suites. Some of them I have to modify (`JsonSuite` for example) because previously bad input was just silently ignored. For now such input is handled according to specified `mode`. Closes #22938 from MaxGekk/json-nulls. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- docs/sql-migration-guide-upgrade.md | 2 ++ .../expressions/jsonExpressions.scala | 26 ++++++++++++------- .../sql/catalyst/json/JacksonParser.scala | 2 +- .../expressions/JsonExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 10 ------- .../datasources/json/JsonSuite.scala | 12 ++++++--- 7 files changed, 31 insertions(+), 25 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 059c9f3057242..f355a515935c8 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1674,7 +1674,7 @@ test_that("column functions", { # check for unparseable df <- as.DataFrame(list(list("a" = ""))) - expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) + expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]]$a, NA) # check if array type in string is correctly supported. 
jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 07079d93f25b6..e8f2bcc9adfb4 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -15,6 +15,8 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. + - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. + - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 52d0677f4022f..543c6c41de58a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -550,15 +550,23 @@ case class JsonToStructs( s"Input schema ${nullableSchema.catalogString} must be a struct, an array or a map.") } - // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = nullableSchema match { - case _: StructType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null - case _: ArrayType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null - case _: MapType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null + private lazy val castRow = nullableSchema match { + case _: StructType => (row: InternalRow) => row + case _: ArrayType => (row: InternalRow) => row.getArray(0) + case _: MapType => (row: InternalRow) => row.getMap(0) + } + + // This converts parsed rows to the desired output by the given schema. + private def convertRow(rows: Iterator[InternalRow]) = { + if (rows.hasNext) { + val result = rows.next() + // JSON's parser produces one record only. 
+ assert(!rows.hasNext) + castRow(result) + } else { + throw new IllegalArgumentException("Expected one row from JSON parser.") + } } val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) @@ -593,7 +601,7 @@ case class JsonToStructs( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { - converter(parser.parse(json.asInstanceOf[UTF8String])) + convertRow(parser.parse(json.asInstanceOf[UTF8String])) } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 57c7f2faf3107..773ff5a7a4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -399,7 +399,7 @@ class JacksonParser( // a null first token is equivalent to testing for input.trim.isEmpty // but it works on any token stream and not just strings parser.nextToken() match { - case null => Nil + case null => throw new RuntimeException("Not found any JSON token") case _ => rootConverter.apply(parser) match { case null => throw new RuntimeException("Root converter returned null") case rows => rows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 6ee8c74010d3d..34bd2a99b2b4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -547,7 +547,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), - null + InternalRow(null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dbb0790a4682c..4cc8a45391996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -240,16 +240,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row("1"), Row("2"))) } - test("SPARK-11226 Skip empty line in json file") { - spark.read - .json(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", "").toDS()) - .createOrReplaceTempView("d") - - checkAnswer( - sql("select count(1) from d"), - Seq(Row(3))) - } - test("SPARK-8828 sum should return null if all input values are null") { checkAnswer( sql("select sum(a), avg(a) from allNulls"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 06032ded42a53..9ea9189cdf7f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1115,6 +1115,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(null, null, null), Row(null, null, null), Row(null, null, null), + Row(null, null, null), Row("str_a_4", "str_b_4", 
"str_c_4"), Row(null, null, null)) ) @@ -1136,6 +1137,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.select($"a", $"b", $"c", $"_unparsed"), Row(null, null, null, "{") :: + Row(null, null, null, "") :: Row(null, null, null, """{"a":1, b:2}""") :: Row(null, null, null, """{"a":{, b:3}""") :: Row("str_a_4", "str_b_4", "str_c_4", null) :: @@ -1150,6 +1152,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.filter($"_unparsed".isNotNull).select($"_unparsed"), Row("{") :: + Row("") :: Row("""{"a":1, b:2}""") :: Row("""{"a":{, b:3}""") :: Row("]") :: Nil @@ -1171,6 +1174,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.selectExpr("a", "b", "c", "_malformed"), Row(null, null, null, "{") :: + Row(null, null, null, "") :: Row(null, null, null, """{"a":1, b:2}""") :: Row(null, null, null, """{"a":{, b:3}""") :: Row("str_a_4", "str_b_4", "str_c_4", null) :: @@ -1813,6 +1817,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType .toDF("value") + .repartition(1) .write .option("compression", "GzIp") .text(path) @@ -1838,6 +1843,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType .toDF("value") + .repartition(1) .write .text(path) @@ -1892,7 +1898,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .text(path) val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) - assert(jsonDF.count() === corruptRecordCount) + assert(jsonDF.count() === corruptRecordCount + 1) // null row for empty file assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) .add("dummy", StringType)) @@ -1905,7 +1911,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { F.count($"dummy").as("valid"), F.count($"_corrupt_record").as("corrupt"), F.count("*").as("count")) - checkAnswer(counts, Row(1, 4, 6)) + checkAnswer(counts, Row(1, 5, 7)) // null row for empty file } } @@ -2513,7 +2519,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } checkCount(2) - countForMalformedJSON(0, Seq("")) + countForMalformedJSON(1, Seq("")) } test("SPARK-25040: empty strings should be disallowed") { From ab2eafb3cdc7631452650c6cac03a92629255347 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Nov 2018 10:50:01 +0800 Subject: [PATCH 097/145] [SPARK-26085][SQL] Key attribute of non-struct type under typed aggregation should be named as "key" too ## What changes were proposed in this pull request? When doing typed aggregation on a Dataset, for struct key type, the key attribute is named as "key". But for non-struct type, the key attribute is named as "value". This key attribute should also be named as "key" for non-struct type. ## How was this patch tested? Added test. Closes #23054 from viirya/SPARK-26085. 
Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 ++ .../org/apache/spark/sql/internal/SQLConf.scala | 12 ++++++++++++ .../apache/spark/sql/KeyValueGroupedDataset.scala | 7 ++++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index e8f2bcc9adfb4..397ca59d96497 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -20,6 +20,8 @@ displayTitle: Spark SQL Upgrading Guide - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. + + - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`. ## Upgrading From Spark SQL 2.3 to 2.4 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cc0e9727812db..7bcf21595ce5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1595,6 +1595,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE = + buildConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue") + .internal() + .doc("When set to true, the key attribute resulted from running `Dataset.groupByKey` " + + "for non-struct key type, will be named as `value`, following the behavior of Spark " + + "version 2.4 and earlier.") + .booleanConf + .createWithDefault(false) + val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields") .doc("Maximum number of fields of sequence-like entries can be converted to strings " + "in debug output. 
Any elements beyond the limit will be dropped and replaced by a" + @@ -2016,6 +2025,9 @@ class SQLConf extends Serializable with Logging { def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG) + def nameNonStructGroupingKeyAsValue: Boolean = + getConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE) + def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 7a47242f69381..2d849c65997a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** @@ -459,7 +460,11 @@ class KeyValueGroupedDataset[K, V] private[sql]( columns.map(_.withInputType(vExprEnc, dataAttributes).named) val keyColumn = if (!kExprEnc.isSerializedAsStruct) { assert(groupingAttributes.length == 1) - groupingAttributes.head + if (SQLConf.get.nameNonStructGroupingKeyAsValue) { + groupingAttributes.head + } else { + Alias(groupingAttributes.head, "key")() + } } else { Alias(CreateStruct(groupingAttributes), "key")() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 540fbff6a3a63..baece2ddac7eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1572,6 +1572,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L)) } + test("SPARK-26085: fix key attribute name for atomic type for typed aggregation") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.groupByKey(x => x).count().schema.head.name == "key") + + // Enable legacy flag to follow previous Spark behavior + withSQLConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE.key -> "true") { + assert(ds.groupByKey(x => x).count().schema.head.name == "value") + } + } + test("SPARK-8288: class with only a companion object constructor") { val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2)) val ds = data.toDS From 8d54bf79f215378fbd95794591a87604a5eaf7a3 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 22 Nov 2018 10:57:19 +0800 Subject: [PATCH 098/145] [SPARK-26099][SQL] Verification of the corrupt column in from_csv/from_json ## What changes were proposed in this pull request? The corrupt column specified via JSON/CSV option *columnNameOfCorruptRecord* must have the `string` type and be `nullable`. This has been already checked in `DataFrameReader`.`csv`/`json` and in `Json`/`CsvFileFormat` but not in `from_json`/`from_csv`. The PR adds such checks inside functions as well. ## How was this patch tested? Added tests to `Json`/`CsvExpressionSuite` for checking type of the corrupt column. They don't check the `nullable` property because `schema` is forcibly casted to nullable. 
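
For illustration only, a sketch of the new check that mirrors the tests added below (it assumes an active SparkSession named `spark`; the column names and sample record are just examples):

```scala
// Sketch only: the corrupt record column is deliberately declared with a non-string type.
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.StructType
import spark.implicits._

val schema = StructType.fromDDL("i INT, _unparsed BOOLEAN")
val options = Map("columnNameOfCorruptRecord" -> "_unparsed")

// With this change the query fails with:
//   AnalysisException: The field for corrupt records must be string type and nullable
Seq("""{"i":"a"}""").toDF("value")
  .select(from_json($"value", schema, options))
  .collect()
```
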
Closes #23070 from MaxGekk/verify-corrupt-column-csv-json. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../sql/catalyst/expressions/ExprUtils.scala | 16 ++++++++++++++ .../catalyst/expressions/csvExpressions.scala | 4 ++++ .../expressions/jsonExpressions.scala | 1 + .../expressions/CsvExpressionsSuite.scala | 11 ++++++++++ .../expressions/JsonExpressionsSuite.scala | 11 ++++++++++ .../apache/spark/sql/DataFrameReader.scala | 21 +++---------------- .../datasources/csv/CSVFileFormat.scala | 9 ++------ .../datasources/json/JsonFileFormat.scala | 11 +++------- 8 files changed, 51 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 040b56cc1caea..89e9071324eff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -67,4 +67,20 @@ object ExprUtils { case _ => throw new AnalysisException("Must use a map() function for options") } + + /** + * A convenient function for schema validation in datasources supporting + * `columnNameOfCorruptRecord` as an option. + */ + def verifyColumnNameOfCorruptRecord( + schema: StructType, + columnNameOfCorruptRecord: String): Unit = { + schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index aff372b899f86..1e4e1c663c90e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -106,6 +106,10 @@ case class CsvToStructs( throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. 
" + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") } + ExprUtils.verifyColumnNameOfCorruptRecord( + nullableSchema, + parsedOptions.columnNameOfCorruptRecord) + val actualSchema = StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 543c6c41de58a..47304d835fdf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -579,6 +579,7 @@ case class JsonToStructs( } val (parserSchema, actualSchema) = nullableSchema match { case s: StructType => + ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord) (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) case other => (StructType(StructField("value", other) :: Nil), other) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index f5aaaec456153..98c93a4946f4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ @@ -226,4 +227,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P InternalRow(17836)) // number of days from 1970-01-01 } } + + test("verify corrupt column") { + checkExceptionInExpression[AnalysisException]( + CsvToStructs( + schema = StructType.fromDDL("i int, _unparsed boolean"), + options = Map("columnNameOfCorruptRecord" -> "_unparsed"), + child = Literal.create("a"), + timeZoneId = gmtId), + expectedErrMsg = "The field for corrupt records must be string type and nullable") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 34bd2a99b2b4d..9b89a27c23770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.PlanTestBase @@ -754,4 +755,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with InternalRow(17836)) // number of days from 1970-01-01 } } + + test("verify corrupt column") { + 
checkExceptionInExpression[AnalysisException]( + JsonToStructs( + schema = StructType.fromDDL("i int, _unparsed boolean"), + options = Map("columnNameOfCorruptRecord" -> "_unparsed"), + child = Literal.create("""{"i":"a"}"""), + timeZoneId = gmtId), + expectedErrMsg = "The field for corrupt records must be string type and nullable") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 52df13d39caa7..f08fd64acd9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.command.DDLUtils @@ -442,7 +443,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) } - verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) @@ -504,7 +505,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { parsedOptions) } - verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) @@ -765,22 +766,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } } - /** - * A convenient function for schema validation in datasources supporting - * `columnNameOfCorruptRecord` as an option. 
- */ - private def verifyColumnNameOfCorruptRecord( - schema: StructType, - columnNameOfCorruptRecord: String): Unit = { - schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = schema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } - } - /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 964b56e706a0b..ff1911d69a6b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser} +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -110,13 +111,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.columnNameOfCorruptRecord) // Check a field requirement for corrupt records here to throw an exception in a driver side - dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = dataSchema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) if (requiredSchema.length == 1 && requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 1f7c9d73f19fe..610f0d1619fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -26,7 +26,8 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -107,13 +108,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val actualSchema = StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) // Check a field requirement for corrupt 
records here to throw an exception in a driver side - dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = dataSchema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) if (requiredSchema.length == 1 && requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { From 15c038497791e7735898356db2464b8732695365 Mon Sep 17 00:00:00 2001 From: Takanobu Asanuma Date: Wed, 21 Nov 2018 23:09:57 -0800 Subject: [PATCH 099/145] [SPARK-26134][CORE] Upgrading Hadoop to 2.7.4 to fix java.version problem ## What changes were proposed in this pull request? When I ran spark-shell on JDK11+28(2018-09-25), It failed with the error below. ``` Exception in thread "main" java.lang.ExceptionInInitializerError at org.apache.hadoop.util.StringUtils.(StringUtils.java:80) at org.apache.hadoop.security.SecurityUtil.getAuthenticationMethod(SecurityUtil.java:611) at org.apache.hadoop.security.UserGroupInformation.initialize(UserGroupInformation.java:273) at org.apache.hadoop.security.UserGroupInformation.ensureInitialized(UserGroupInformation.java:261) at org.apache.hadoop.security.UserGroupInformation.loginUserFromSubject(UserGroupInformation.java:791) at org.apache.hadoop.security.UserGroupInformation.getLoginUser(UserGroupInformation.java:761) at org.apache.hadoop.security.UserGroupInformation.getCurrentUser(UserGroupInformation.java:634) at org.apache.spark.util.Utils$.$anonfun$getCurrentUserName$1(Utils.scala:2427) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.util.Utils$.getCurrentUserName(Utils.scala:2427) at org.apache.spark.SecurityManager.(SecurityManager.scala:79) at org.apache.spark.deploy.SparkSubmit.secMgr$lzycompute$1(SparkSubmit.scala:359) at org.apache.spark.deploy.SparkSubmit.secMgr$1(SparkSubmit.scala:359) at org.apache.spark.deploy.SparkSubmit.$anonfun$prepareSubmitEnvironment$9(SparkSubmit.scala:367) at scala.Option.map(Option.scala:146) at org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment(SparkSubmit.scala:367) at org.apache.spark.deploy.SparkSubmit.submit(SparkSubmit.scala:143) at org.apache.spark.deploy.SparkSubmit.doSubmit(SparkSubmit.scala:86) at org.apache.spark.deploy.SparkSubmit$$anon$2.doSubmit(SparkSubmit.scala:927) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:936) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: java.lang.StringIndexOutOfBoundsException: begin 0, end 3, length 2 at java.base/java.lang.String.checkBoundsBeginEnd(String.java:3319) at java.base/java.lang.String.substring(String.java:1874) at org.apache.hadoop.util.Shell.(Shell.java:52) ``` This is a Hadoop issue that fails to parse some java.version. It has been fixed from Hadoop-2.7.4(see [HADOOP-14586](https://issues.apache.org/jira/browse/HADOOP-14586)). Note, Hadoop-2.7.5 or upper have another problem with Spark ([SPARK-25330](https://issues.apache.org/jira/browse/SPARK-25330)). So upgrading to 2.7.4 would be fine for now. ## How was this patch tested? Existing tests. Closes #23101 from tasanuma/SPARK-26134. 
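
For context only, a hedged sketch of the underlying parsing problem; the exact Hadoop code differs, this just reproduces the failing string operation from the stack trace above (the helper name is made up):

```scala
// java.version looks like "1.8.0_191" on JDK 8 but is simply "11" on JDK 11+.
def fragileVersionPrefix(v: String): String = v.substring(0, 3) // mimics the check in Hadoop 2.7.3's Shell

fragileVersionPrefix("1.8.0_191") // "1.8"
fragileVersionPrefix("11")        // throws StringIndexOutOfBoundsException: begin 0, end 3, length 2
// Hadoop 2.7.4 (HADOOP-14586) parses the version without this fixed-length assumption.
```
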
Authored-by: Takanobu Asanuma Signed-off-by: Dongjoon Hyun --- assembly/README | 2 +- dev/deps/spark-deps-hadoop-2.7 | 31 ++++++++++--------- pom.xml | 2 +- .../kubernetes/integration-tests/README.md | 2 +- .../hive/client/IsolatedClientLoader.scala | 2 +- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/assembly/README b/assembly/README index d5dafab477410..1fd6d8858348c 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.7.3 + -Dhadoop.version=2.7.4 diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index c2f5755ca9925..ec7c304c9e36b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -64,21 +64,21 @@ gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.3.jar -hadoop-auth-2.7.3.jar -hadoop-client-2.7.3.jar -hadoop-common-2.7.3.jar -hadoop-hdfs-2.7.3.jar -hadoop-mapreduce-client-app-2.7.3.jar -hadoop-mapreduce-client-common-2.7.3.jar -hadoop-mapreduce-client-core-2.7.3.jar -hadoop-mapreduce-client-jobclient-2.7.3.jar -hadoop-mapreduce-client-shuffle-2.7.3.jar -hadoop-yarn-api-2.7.3.jar -hadoop-yarn-client-2.7.3.jar -hadoop-yarn-common-2.7.3.jar -hadoop-yarn-server-common-2.7.3.jar -hadoop-yarn-server-web-proxy-2.7.3.jar +hadoop-annotations-2.7.4.jar +hadoop-auth-2.7.4.jar +hadoop-client-2.7.4.jar +hadoop-common-2.7.4.jar +hadoop-hdfs-2.7.4.jar +hadoop-mapreduce-client-app-2.7.4.jar +hadoop-mapreduce-client-common-2.7.4.jar +hadoop-mapreduce-client-core-2.7.4.jar +hadoop-mapreduce-client-jobclient-2.7.4.jar +hadoop-mapreduce-client-shuffle-2.7.4.jar +hadoop-yarn-api-2.7.4.jar +hadoop-yarn-client-2.7.4.jar +hadoop-yarn-common-2.7.4.jar +hadoop-yarn-server-common-2.7.4.jar +hadoop-yarn-server-web-proxy-2.7.4.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar @@ -117,6 +117,7 @@ jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jetty-6.1.26.jar +jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar jline-2.14.6.jar joda-time-2.9.3.jar diff --git a/pom.xml b/pom.xml index 08a29d2d52310..93075e9b06a68 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ spark 1.7.16 1.2.17 - 2.7.3 + 2.7.4 2.5.0 ${hadoop.version} 3.4.6 diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 64f8e77597eba..73fc0581d64f5 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -107,7 +107,7 @@ properties to Maven. 
For example: mvn integration-test -am -pl :spark-kubernetes-integration-tests_2.11 \ -Pkubernetes -Pkubernetes-integration-tests \ - -Phadoop-2.7 -Dhadoop.version=2.7.3 \ + -Phadoop-2.7 -Dhadoop.version=2.7.4 \ -Dspark.kubernetes.test.sparkTgz=spark-3.0.0-SNAPSHOT-bin-example.tgz \ -Dspark.kubernetes.test.imageTag=sometag \ -Dspark.kubernetes.test.imageRepo=docker.io/somerepo \ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index f56ca8cb08553..ca98c30add168 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -65,7 +65,7 @@ private[hive] object IsolatedClientLoader extends Logging { case e: RuntimeException if e.getMessage.contains("hadoop") => // If the error message contains hadoop, it is probably because the hadoop // version cannot be resolved. - val fallbackVersion = "2.7.3" + val fallbackVersion = "2.7.4" logWarning(s"Failed to resolve Hadoop artifacts for the version $hadoopVersion. We " + s"will change the hadoop version from $hadoopVersion to $fallbackVersion and try " + "again. Hadoop classes will not be shared between Spark and Hive metastore client. " + From ab00533490953164cb2360bf2b9adc2c9fa962db Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 22 Nov 2018 02:27:06 -0800 Subject: [PATCH 100/145] [SPARK-26129][SQL] edge behavior for QueryPlanningTracker.topRulesByTime - followup patch ## What changes were proposed in this pull request? This is an addendum patch for SPARK-26129 that defines the edge case behavior for QueryPlanningTracker.topRulesByTime. ## How was this patch tested? Added unit tests for each behavior. Closes #23110 from rxin/SPARK-26129-1. Authored-by: Reynold Xin Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/QueryPlanningTracker.scala | 17 ++++++++++++----- .../catalyst/QueryPlanningTrackerSuite.scala | 9 ++++++++- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala index 420f2a1f20997..244081cd160b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala @@ -116,12 +116,19 @@ class QueryPlanningTracker { def phases: Map[String, Long] = phaseToTimeNs.asScala.toMap - /** Returns the top k most expensive rules (as measured by time). */ + /** + * Returns the top k most expensive rules (as measured by time). If k is larger than the rules + * seen so far, return all the rules. If there is no rule seen so far or k <= 0, return empty seq. 
+ */ def topRulesByTime(k: Int): Seq[(String, RuleSummary)] = { - val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs) - val q = new BoundedPriorityQueue(k)(orderingByTime) - rulesMap.asScala.foreach(q.+=) - q.toSeq.sortBy(r => -r._2.totalTimeNs) + if (k <= 0) { + Seq.empty + } else { + val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs) + val q = new BoundedPriorityQueue(k)(orderingByTime) + rulesMap.asScala.foreach(q.+=) + q.toSeq.sortBy(r => -r._2.totalTimeNs) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala index f42c262dfbdd8..120b284a77854 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala @@ -62,17 +62,24 @@ class QueryPlanningTrackerSuite extends SparkFunSuite { test("topRulesByTime") { val t = new QueryPlanningTracker + + // Return empty seq when k = 0 + assert(t.topRulesByTime(0) == Seq.empty) + assert(t.topRulesByTime(1) == Seq.empty) + t.recordRuleInvocation("r2", 2, effective = true) t.recordRuleInvocation("r4", 4, effective = true) t.recordRuleInvocation("r1", 1, effective = false) t.recordRuleInvocation("r3", 3, effective = false) + // k <= total size + assert(t.topRulesByTime(0) == Seq.empty) val top = t.topRulesByTime(2) assert(top.size == 2) assert(top(0)._1 == "r4") assert(top(1)._1 == "r3") - // Don't crash when k > total size + // k > total size assert(t.topRulesByTime(10).size == 4) } } From aeda76e2b74ef07b2814770d68cf145cdbb0197c Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Thu, 22 Nov 2018 15:43:04 -0600 Subject: [PATCH 101/145] [GRAPHX] Remove unused variables left over by previous refactoring. ## What changes were proposed in this pull request? Some variables were previously used for indexing the routing table's backing array, but that indexing now happens elsewhere, and so the variables aren't needed. ## How was this patch tested? Unit tests. (This contribution is my original work and I license the work to Spark under its open source license.) Closes #23112 from huonw/remove-unused-variables. 
Authored-by: Huon Wilson Signed-off-by: Sean Owen --- .../apache/spark/graphx/impl/ShippableVertexPartition.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index a4e293d74a012..184b96426fa9b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -117,13 +117,11 @@ class ShippableVertexPartition[VD: ClassTag]( val initialSize = if (shipSrc && shipDst) routingTable.partitionSize(pid) else 64 val vids = new PrimitiveVector[VertexId](initialSize) val attrs = new PrimitiveVector[VD](initialSize) - var i = 0 routingTable.foreachWithinEdgePartition(pid, shipSrc, shipDst) { vid => if (isDefined(vid)) { vids += vid attrs += this(vid) } - i += 1 } (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array)) } @@ -137,12 +135,10 @@ class ShippableVertexPartition[VD: ClassTag]( def shipVertexIds(): Iterator[(PartitionID, Array[VertexId])] = { Iterator.tabulate(routingTable.numEdgePartitions) { pid => val vids = new PrimitiveVector[VertexId](routingTable.partitionSize(pid)) - var i = 0 routingTable.foreachWithinEdgePartition(pid, true, true) { vid => if (isDefined(vid)) { vids += vid } - i += 1 } (pid, vids.trim().array) } From dd8c179c28c5df20210b70a69d93d866ccaca4cc Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 22 Nov 2018 15:45:25 -0600 Subject: [PATCH 102/145] [SPARK-25867][ML] Remove KMeans computeCost ## What changes were proposed in this pull request? The PR removes the deprecated method `computeCost` of `KMeans`. ## How was this patch tested? NA Closes #22875 from mgaido91/SPARK-25867. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../org/apache/spark/ml/clustering/KMeans.scala | 16 ---------------- .../apache/spark/ml/clustering/KMeansSuite.scala | 12 +++++------- project/MimaExcludes.scala | 3 +++ python/pyspark/ml/clustering.py | 16 ---------------- 4 files changed, 8 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 498310d6644e1..919496aa1a840 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -143,22 +143,6 @@ class KMeansModel private[ml] ( @Since("2.0.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) - /** - * Return the K-means cost (sum of squared distances of points to their nearest center) for this - * model on the given data. - * - * @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator - * instead. You can also get the cost on the training dataset in the summary. - */ - @deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " + - "instead. You can also get the cost on the training dataset in the summary.", "2.4.0") - @Since("2.0.0") - def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) - val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - parentModel.computeCost(data) - } - /** * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. 
* diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ccbceab53bb66..4f47d91f0d0d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -117,7 +117,6 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(clusters === Set(0, 1, 2, 3, 4)) } - assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) // Check validity of model summary @@ -132,7 +131,6 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } assert(summary.cluster.columns === Array(predictionColName)) assert(summary.trainingCost < 0.1) - assert(model.computeCost(dataset) == summary.trainingCost) val clusterSizes = summary.clusterSizes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) @@ -201,15 +199,15 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } test("KMean with Array input") { - def trainAndComputeCost(dataset: Dataset[_]): Double = { + def trainAndGetCost(dataset: Dataset[_]): Double = { val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) - model.computeCost(dataset) + model.summary.trainingCost } val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) - val trueCost = trainAndComputeCost(newDataset) - val doubleArrayCost = trainAndComputeCost(newDatasetD) - val floatArrayCost = trainAndComputeCost(newDatasetF) + val trueCost = trainAndGetCost(newDataset) + val doubleArrayCost = trainAndGetCost(newDatasetD) + val floatArrayCost = trainAndGetCost(newDatasetF) // checking the cost is fine enough as a sanity check assert(trueCost ~== doubleArrayCost absTol 1e-6) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9089c7d9ffc70..333adb0c84025 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25867] Remove KMeans computeCost + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), + // [SPARK-26127] Remove deprecated setters from tree regression and classification models ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index aaeeeb82d3d86..d0b507ec5dad4 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -335,20 +335,6 @@ def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] - @since("2.0.0") - def computeCost(self, dataset): - """ - Return the K-means cost (sum of squared distances of points to their nearest center) - for this model on the given data. - - ..note:: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead. - You can also get the cost on the training dataset in the summary. - """ - warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator " - "instead. 
You can also get the cost on the training dataset in the summary.", - DeprecationWarning) - return self._call_java("computeCost", dataset) - @property @since("2.1.0") def hasSummary(self): @@ -387,8 +373,6 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol >>> centers = model.clusterCenters() >>> len(centers) 2 - >>> model.computeCost(df) - 2.0 >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction From d81d95a7e8a621e42c9c61305c32df72b6e868be Mon Sep 17 00:00:00 2001 From: oraviv Date: Thu, 22 Nov 2018 15:48:01 -0600 Subject: [PATCH 103/145] [SPARK-19368][MLLIB] BlockMatrix.toIndexedRowMatrix() optimization for sparse matrices ## What changes were proposed in this pull request? Optimization [SPARK-12869] was made for dense matrices but caused great performance issue for sparse matrices because manipulating them is very inefficient. When manipulating sparse matrices in Breeze we better use VectorBuilder. ## How was this patch tested? checked it against a use case that we have that after moving to Spark 2 took 6.5 hours instead of 20 mins. After the change it is back to 20 mins again. Closes #16732 from uzadude/SparseVector_optimization. Authored-by: oraviv Signed-off-by: Sean Owen --- .../linalg/distributed/BlockMatrix.scala | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 7caacd13b3459..e58860fea97d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.linalg.distributed +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM} import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, SparseVector => BSV, Vector => BV} - import org.apache.spark.{Partitioner, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging @@ -28,6 +27,7 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel + /** * A grid partitioner, which uses a regular grid to partition coordinates. 
* @@ -273,24 +273,37 @@ class BlockMatrix @Since("1.3.0") ( require(cols < Int.MaxValue, s"The number of columns should be less than Int.MaxValue ($cols).") val rows = blocks.flatMap { case ((blockRowIdx, blockColIdx), mat) => - mat.rowIter.zipWithIndex.map { + mat.rowIter.zipWithIndex.filter(_._1.size > 0).map { case (vector, rowIdx) => - blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector.asBreeze)) + blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector)) } }.groupByKey().map { case (rowIdx, vectors) => - val numberNonZeroPerRow = vectors.map(_._2.activeSize).sum.toDouble / cols.toDouble - - val wholeVector = if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz - BSV.zeros[Double](cols) - } else { - BDV.zeros[Double](cols) - } + val numberNonZero = vectors.map(_._2.numActives).sum + val numberNonZeroPerRow = numberNonZero.toDouble / cols.toDouble + + val wholeVector = + if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz + val arrBufferIndices = new ArrayBuffer[Int](numberNonZero) + val arrBufferValues = new ArrayBuffer[Double](numberNonZero) + + vectors.foreach { case (blockColIdx: Int, vec: Vector) => + val offset = colsPerBlock * blockColIdx + vec.foreachActive { case (colIdx: Int, value: Double) => + arrBufferIndices += offset + colIdx + arrBufferValues += value + } + } + Vectors.sparse(cols, arrBufferIndices.toArray, arrBufferValues.toArray) + } else { + val wholeVectorBuf = BDV.zeros[Double](cols) + vectors.foreach { case (blockColIdx: Int, vec: Vector) => + val offset = colsPerBlock * blockColIdx + wholeVectorBuf(offset until Math.min(cols, offset + colsPerBlock)) := vec.asBreeze + } + Vectors.fromBreeze(wholeVectorBuf) + } - vectors.foreach { case (blockColIdx: Int, vec: BV[_]) => - val offset = colsPerBlock * blockColIdx - wholeVector(offset until Math.min(cols, offset + colsPerBlock)) := vec - } - new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector)) + IndexedRow(rowIdx, wholeVector) } new IndexedRowMatrix(rows) } From 1d766f0e222c24e8e8cad68e664e83f4f71f7541 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 22 Nov 2018 14:49:41 -0800 Subject: [PATCH 104/145] [SPARK-26144][BUILD] `build/mvn` should detect `scala.version` based on `scala.binary.version` ## What changes were proposed in this pull request? Currently, `build/mvn` downloads and uses **Scala 2.12.7** in `Scala-2.11` Jenkins job. The root cause is `build/mvn` got the first match from `pom.xml` blindly. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.11/6/consoleFull ``` exec: curl -s -L https://downloads.lightbend.com/zinc/0.3.15/zinc-0.3.15.tgz exec: curl -s -L https://downloads.lightbend.com/scala/2.12.7/scala-2.12.7.tgz exec: curl -s -L https://www.apache.org/dyn/closer.lua?action=download&filename=/maven/maven-3/3.5.4/binaries/apache-maven-3.5.4-bin.tar.gz ``` ## How was this patch tested? Manual. ``` $ build/mvn clean exec: curl --progress-bar -L https://downloads.lightbend.com/scala/2.12.7/scala-2.12.7.tgz ... $ git clean -fdx $ dev/change-scala-version.sh 2.11 $ build/mvn clean exec: curl --progress-bar -L https://downloads.lightbend.com/scala/2.11.12/scala-2.11.12.tgz ``` Closes #23118 from dongjoon-hyun/SPARK-26144. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- build/mvn | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build/mvn b/build/mvn index 3816993b4e5c8..4cb10e0d03fa4 100755 --- a/build/mvn +++ b/build/mvn @@ -116,7 +116,8 @@ install_zinc() { # the build/ folder install_scala() { # determine the Scala version used in Spark - local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + local scala_binary_version=`grep "scala.binary.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | grep ${scala_binary_version} | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} From 76aae7f1fd512f150ffcdb618107b12e1e97fe43 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Nov 2018 14:54:00 -0800 Subject: [PATCH 105/145] [SPARK-24553][UI][FOLLOWUP] Fix unnecessary UI redirect ## What changes were proposed in this pull request? This PR is a follow-up PR of #21600 to fix the unnecessary UI redirect. ## How was this patch tested? Local verification Closes #23116 from jerryshao/SPARK-24553. Authored-by: jerryshao Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../scala/org/apache/spark/ui/storage/StoragePage.scala | 2 +- .../org/apache/spark/ui/storage/StoragePageSuite.scala | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index b9abd39b4705d..766efc15e26ba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -368,7 +368,7 @@ private[ui] class StagePagedTable( {if (cachedRddInfos.nonEmpty) { Text("RDD: ") ++ cachedRddInfos.map { i => - {i.name} + {i.name} } }}
    {s.details}
    diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 3eb546e336e99..2488197814ffd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -78,7 +78,7 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends {rdd.id} - {rdd.name} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index cdc7f541b9552..06f01a60868f9 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -81,19 +81,19 @@ class StoragePageSuite extends SparkFunSuite { Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=1")) + Some("http://localhost:4040/storage/rdd/?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === Seq("2", "rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=2")) + Some("http://localhost:4040/storage/rdd/?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=3")) + Some("http://localhost:4040/storage/rdd/?id=3")) } test("empty rddTable") { From 0ec7b99ea2b638453ed38bb092905bee4f907fe5 Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Fri, 23 Nov 2018 08:55:00 +0800 Subject: [PATCH 106/145] [SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float GROUP BY treats -0.0 and 0.0 as different values which is unlike hive's behavior. In addition current behavior with codegen is unpredictable (see example in JIRA ticket). ## What changes were proposed in this pull request? In Platform.putDouble/Float() checking if the value is -0.0, and if so replacing with 0.0. This is used by UnsafeRow so it won't have -0.0 values. ## How was this patch tested? Added tests Closes #23043 from adoron/adoron-spark-26021-replace-minus-zero-with-zero. 
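
For illustration only, a minimal sketch of the grouping behavior this change targets, mirroring the test added below (it assumes an active SparkSession named `spark`; the column name is arbitrary):

```scala
// Sketch only: assumes an active SparkSession `spark`.
import spark.implicits._

val grouped = Seq(0.0d, -0.0d, 0.0d).toDF("d").groupBy("d").count()
grouped.show()
// Before the fix: 0.0 and -0.0 could land in separate groups, and codegen made the result unpredictable.
// After the fix:  a single group with count 3, because UnsafeRow now writes -0.0 as 0.0.
```
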
Authored-by: Alon Doron Signed-off-by: Wenchen Fan --- .../java/org/apache/spark/unsafe/Platform.java | 10 ++++++++++ .../org/apache/spark/unsafe/PlatformUtilSuite.java | 14 ++++++++++++++ .../spark/sql/catalyst/expressions/UnsafeRow.java | 6 ------ .../catalyst/expressions/codegen/UnsafeWriter.java | 6 ------ .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 ++++++++++++++ .../scala/org/apache/spark/sql/QueryTest.scala | 5 ++++- 6 files changed, 42 insertions(+), 13 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 076b693f81c88..4563efcfcf474 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -174,6 +174,11 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if (value == -0.0f) { + value = 0.0f; + } _UNSAFE.putFloat(object, offset, value); } @@ -182,6 +187,11 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if (value == -0.0d) { + value = 0.0d; + } _UNSAFE.putDouble(object, offset, value); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..ab34324eb54cc 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,4 +157,18 @@ public void heapMemoryReuse() { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } + + @Test + // SPARK-26021 + public void writeMinusZeroIsReplacedWithZero() { + byte[] doubleBytes = new byte[Double.BYTES]; + byte[] floatBytes = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); + double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); + float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); + + Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); + Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a76e6ef8c91c1..9bf9452855f5f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) { public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) { public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Float.isNaN(value)) { - value = Float.NaN; - } 
Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 2781655002000..95263a0da95a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) { } protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(getBuffer(), offset, value); } protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d9ba6e2ce5120..ff64edcd07f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -723,4 +723,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { + val colName = "i" + val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() + val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() + + assert(doubles.length == 1) + assert(floats.length == 1) + // using compare since 0.0 == -0.0 is true + assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) + assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) + assert(doubles(0).getLong(1) == 3) + assert(floats(0).getLong(1) == 3) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index baca9c1cfb9a0..8ba67239fb907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -289,7 +289,7 @@ object QueryTest { def prepareRow(row: Row): Row = { Row.fromSeq(row.toSeq.map { case null => null - case d: java.math.BigDecimal => BigDecimal(d) + case bd: java.math.BigDecimal => BigDecimal(bd) // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ case seq: Seq[_] => seq.map { case b: java.lang.Byte => b.byteValue @@ -303,6 +303,9 @@ object QueryTest { // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq case r: Row => prepareRow(r) + // spark treats -0.0 as 0.0 + case d: Double if d == -0.0d => 0.0d + case f: Float if f == -0.0f => 0.0f case o => o }) } From 1d3dd58d21400b5652b75af7e7e53aad85a31528 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 22 Nov 2018 22:45:08 -0800 Subject: [PATCH 107/145] [SPARK-25954][SS][FOLLOWUP][TEST-MAVEN] Add Zookeeper 3.4.7 test dependency to Kafka modules ## What changes were proposed in this pull request? This is a followup of #23099 . After upgrading to Kafka 2.1.0, maven test fails due to Zookeeper test dependency while sbt test succeeds. 
- [sbt test on master branch](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/5203/) - [maven test on master branch](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7/5653/) The root cause is that the embedded Kafka server is using [Zookeepr 3.4.7 API](https://zookeeper.apache.org/doc/r3.4.7/api/org/apache/zookeeper/AsyncCallback.MultiCallback.html ) while Apache Spark provides Zookeeper 3.4.6. This PR adds a test dependency. ``` KafkaMicroBatchV2SourceSuite: *** RUN ABORTED *** ... org.apache.spark.sql.kafka010.KafkaTestUtils.setupEmbeddedKafkaServer(KafkaTestUtils.scala:123) ... Cause: java.lang.ClassNotFoundException: org.apache.zookeeper.AsyncCallback$MultiCallback at java.net.URLClassLoader.findClass(URLClassLoader.java:381) at java.lang.ClassLoader.loadClass(ClassLoader.java:424) at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:331) at java.lang.ClassLoader.loadClass(ClassLoader.java:357) at kafka.zk.KafkaZkClient$.apply(KafkaZkClient.scala:1693) at kafka.server.KafkaServer.createZkClient$1(KafkaServer.scala:348) at kafka.server.KafkaServer.initZkClient(KafkaServer.scala:372) at kafka.server.KafkaServer.startup(KafkaServer.scala:202) at org.apache.spark.sql.kafka010.KafkaTestUtils.$anonfun$setupEmbeddedKafkaServer$2(KafkaTestUtils.scala:120) at org.apache.spark.sql.kafka010.KafkaTestUtils.$anonfun$setupEmbeddedKafkaServer$2$adapted(KafkaTestUtils.scala:116) ... ``` ## How was this patch tested? Pass the maven Jenkins test. Closes #23119 from dongjoon-hyun/SPARK-25954-2. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- external/kafka-0-10-sql/pom.xml | 7 +++++++ external/kafka-0-10/pom.xml | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index d97e8cf18605e..1af407167597b 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -89,6 +89,13 @@
    + + + org.apache.zookeeper + zookeeper + 3.4.7 + test + net.sf.jopt-simple jopt-simple diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index cfc45559d8e34..ea18b7e035915 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -74,6 +74,13 @@ + + + org.apache.zookeeper + zookeeper + 3.4.7 + test + net.sf.jopt-simple jopt-simple From 92fc0a8f9619a8e7f8382d6a5c288aeceb03a472 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 23 Nov 2018 06:18:44 -0600 Subject: [PATCH 108/145] [SPARK-26069][TESTS][FOLLOWUP] Add another possible error message ## What changes were proposed in this pull request? `org.apache.spark.network.RpcIntegrationSuite.sendRpcWithStreamFailures` is still flaky and here is error message: ``` sbt.ForkMain$ForkError: java.lang.AssertionError: Got a non-empty set [Failed to send RPC RPC 8249697863992194475 to /172.17.0.2:41177: java.io.IOException: Broken pipe] at org.junit.Assert.fail(Assert.java:88) at org.junit.Assert.assertTrue(Assert.java:41) at org.apache.spark.network.RpcIntegrationSuite.assertErrorAndClosed(RpcIntegrationSuite.java:389) at org.apache.spark.network.RpcIntegrationSuite.sendRpcWithStreamFailures(RpcIntegrationSuite.java:347) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50) at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12) at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47) at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17) at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.internal.runners.statements.RunBefores.evaluate(RunBefores.java:26) at org.junit.internal.runners.statements.RunAfters.evaluate(RunAfters.java:27) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runners.Suite.runChild(Suite.java:128) at org.junit.runners.Suite.runChild(Suite.java:27) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runner.JUnitCore.run(JUnitCore.java:137) at org.junit.runner.JUnitCore.run(JUnitCore.java:115) at com.novocode.junit.JUnitRunner$1.execute(JUnitRunner.java:132) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at 
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` This happened when the second RPC message was being sent but the connection was closed at the same time. ## How was this patch tested? Jenkins Closes #23109 from zsxwing/SPARK-26069-2. Authored-by: Shixiong Zhu Signed-off-by: Sean Owen --- .../spark/network/RpcIntegrationSuite.java | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 45f4a1808562d..1c0aa4da27ff9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -371,18 +371,20 @@ private void assertErrorsContain(Set errors, Set contains) { private void assertErrorAndClosed(RpcResult result, String expectedError) { assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); - // we expect 1 additional error, which should contain one of the follow messages: - // - "closed" - // - "Connection reset" - // - "java.nio.channels.ClosedChannelException" Set errors = result.errorMessages; assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + errors, 2, errors.size()); + // We expect 1 additional error due to closed connection and here are possible keywords in the + // error message. + Set possibleClosedErrors = Sets.newHashSet( + "closed", + "Connection reset", + "java.nio.channels.ClosedChannelException", + "java.io.IOException: Broken pipe" + ); Set containsAndClosed = Sets.newHashSet(expectedError); - containsAndClosed.add("closed"); - containsAndClosed.add("Connection reset"); - containsAndClosed.add("java.nio.channels.ClosedChannelException"); + containsAndClosed.addAll(possibleClosedErrors); Pair, Set> r = checkErrorsContain(errors, containsAndClosed); @@ -390,7 +392,9 @@ private void assertErrorAndClosed(RpcResult result, String expectedError) { Set errorsNotFound = r.getRight(); assertEquals( - "The size of " + errorsNotFound.toString() + " was not 2", 2, errorsNotFound.size()); + "The size of " + errorsNotFound + " was not " + (possibleClosedErrors.size() - 1), + possibleClosedErrors.size() - 1, + errorsNotFound.size()); for (String err: errorsNotFound) { assertTrue("Found a wrong error " + err, containsAndClosed.contains(err)); } From 466d011d3515723653e41d8b1d0b6150b9945f52 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Fri, 23 Nov 2018 21:12:25 +0800 Subject: [PATCH 109/145] [SPARK-26117][CORE][SQL] use SparkOutOfMemoryError instead of OutOfMemoryError when catch exception ## What changes were proposed in this pull request? the pr #20014 which introduced `SparkOutOfMemoryError` to avoid killing the entire executor when an `OutOfMemoryError `is thrown. so apply for memory using `MemoryConsumer. allocatePage `when catch exception, use `SparkOutOfMemoryError `instead of `OutOfMemoryError` ## How was this patch tested? N / A Closes #23084 from heary-cao/SparkOutOfMemoryError. 
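For context, the allocation pattern this change standardizes on looks roughly like the sketch below; the consumer class and method names are illustrative and are not taken from the diff.
```
import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.unsafe.memory.MemoryBlock

// Illustrative consumer (hypothetical class, not part of the patch). If Spark's memory
// manager cannot grant a page, allocatePage throws SparkOutOfMemoryError, which the
// consumer catches so it can fall back or spill; a genuine JVM OutOfMemoryError still
// propagates and is handled by the executor's uncaught-error path.
class PageBackedBuffer(tmm: TaskMemoryManager) extends MemoryConsumer(tmm) {
  private var page: MemoryBlock = _

  // Nothing to spill in this sketch.
  override def spill(size: Long, trigger: MemoryConsumer): Long = 0L

  def tryAcquirePage(required: Long): Boolean = {
    try {
      page = allocatePage(required) // may throw SparkOutOfMemoryError
      true
    } catch {
      case _: SparkOutOfMemoryError => false // degrade gracefully instead of killing the executor
    }
  }

  def close(): Unit = if (page != null) { freePage(page); page = null }
}
```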
Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../java/org/apache/spark/memory/MemoryConsumer.java | 10 +++++----- .../org/apache/spark/unsafe/map/BytesToBytesMap.java | 5 +++-- .../unsafe/sort/UnsafeExternalSorterSuite.java | 7 ++++--- .../unsafe/sort/UnsafeInMemorySorterSuite.java | 5 +++-- .../catalyst/expressions/RowBasedKeyValueBatch.java | 3 ++- .../apache/spark/sql/execution/python/RowQueue.scala | 4 ++-- 6 files changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 8371deca7311d..4bfd2d358f36f 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -83,10 +83,10 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Allocates a LongArray of `size`. Note that this method may throw `OutOfMemoryError` if Spark - * doesn't have enough memory for this allocation, or throw `TooLargePageException` if this - * `LongArray` is too large to fit in a single page. The caller side should take care of these - * two exceptions, or make sure the `size` is small enough that won't trigger exceptions. + * Allocates a LongArray of `size`. Note that this method may throw `SparkOutOfMemoryError` + * if Spark doesn't have enough memory for this allocation, or throw `TooLargePageException` + * if this `LongArray` is too large to fit in a single page. The caller side should take care of + * these two exceptions, or make sure the `size` is small enough that won't trigger exceptions. * * @throws SparkOutOfMemoryError * @throws TooLargePageException @@ -111,7 +111,7 @@ public void freeArray(LongArray array) { /** * Allocate a memory block with at least `required` bytes. 
* - * @throws OutOfMemoryError + * @throws SparkOutOfMemoryError */ protected MemoryBlock allocatePage(long required) { MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 9b6cbab38cbcc..a4e88598f7607 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -31,6 +31,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; @@ -741,7 +742,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff if (numKeys >= growthThreshold && longArray.size() < MAX_CAPACITY) { try { growAndRehash(); - } catch (OutOfMemoryError oom) { + } catch (SparkOutOfMemoryError oom) { canGrowArray = false; } } @@ -757,7 +758,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff private boolean acquireNewPage(long required) { try { currentPage = allocatePage(required); - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { return false; } dataPages.add(currentPage); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 411cd5cb57331..d1b29d90ad913 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -38,6 +38,7 @@ import org.apache.spark.executor.TaskMetrics; import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; @@ -534,10 +535,10 @@ public void testOOMDuringSpill() throws Exception { insertNumber(sorter, 1024); fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); } - // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) - catch (OutOfMemoryError oom){ + // we expect an SparkOutOfMemoryError here, anything else (i.e the original NPE is a failure) + catch (SparkOutOfMemoryError oom){ String oomStackTrace = Utils.exceptionString(oom); - assertThat("expected OutOfMemoryError in " + + assertThat("expected SparkOutOfMemoryError in " + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", oomStackTrace, Matchers.containsString( diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 85ffdca436e14..b0d485f0c953f 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -27,6 +27,7 @@ import org.apache.spark.SparkConf; import 
org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -178,8 +179,8 @@ public int compare( testMemoryManager.markExecutionAsOutOfMemoryOnce(); try { sorter.reset(); - fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); - } catch (OutOfMemoryError oom) { + fail("expected SparkOutOfMemoryError but it seems operation surprisingly succeeded"); + } catch (SparkOutOfMemoryError oom) { // as expected } // [SPARK-21907] this failed on NPE at diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 460513816dfd9..6344cf18c11b8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -126,7 +127,7 @@ public final void close() { private boolean acquirePage(long requiredSize) { try { page = allocatePage(requiredSize); - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { logger.warn("Failed to allocate page ({} bytes).", requiredSize); return false; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index d2820ff335ecf..eb12641f548ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -23,7 +23,7 @@ import com.google.common.io.Closeables import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.NioBufferedFileInputStream -import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.serializer.SerializerManager import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform @@ -226,7 +226,7 @@ private[python] case class HybridRowQueue( val page = try { allocatePage(required) } catch { - case _: OutOfMemoryError => + case _: SparkOutOfMemoryError => null } val buffer = if (page != null) { From 8e8d1177e623d5f995fb9ba1d9574675e1e70d56 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 24 Nov 2018 00:50:20 +0900 Subject: [PATCH 110/145] [SPARK-26108][SQL] Support custom lineSep in CSV datasource ## What changes were proposed in this pull request? In the PR, I propose new options for CSV datasource - `lineSep` similar to Text and JSON datasource. The option allows to specify custom line separator of maximum length of 2 characters (because of a restriction in `uniVocity` parser). New option can be used in reading and writing CSV files. ## How was this patch tested? Added a few tests with custom `lineSep` for enabled/disabled `multiLine` in read as well as tests in write. Also I added roundtrip tests. 
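As a usage illustration (not part of the patch), the new option plugs in like the sketch below; the path, the `|` separator, and the assumption that `spark` is an existing SparkSession (for example in spark-shell) are made up for the example.
```
// Hypothetical round trip with a custom single-character line separator.
import spark.implicits._

val df = Seq("a", "b", "c").toDF("value")

// Write records separated by '|' instead of the default '\n'.
df.write.option("lineSep", "|").csv("/tmp/pipe-separated")

// Read them back with the same single-character separator.
val readBack = spark.read
  .option("lineSep", "|")
  .schema("value STRING")
  .csv("/tmp/pipe-separated")
```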
Closes #23080 from MaxGekk/csv-line-sep. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/readwriter.py | 13 ++- python/pyspark/sql/streaming.py | 7 +- .../spark/sql/catalyst/csv/CSVOptions.scala | 23 +++- .../apache/spark/sql/DataFrameReader.scala | 2 + .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/csv/CSVDataSource.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 + .../execution/datasources/csv/CSVSuite.scala | 110 +++++++++++++++++- 8 files changed, 151 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 726de4a965418..1d2dd4d808930 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -353,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None): + samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -453,6 +453,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -472,7 +475,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale) + enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -868,7 +871,7 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None): + charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None): r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -922,6 +925,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the default UTF-8 charset will be used. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, ``""``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. Maximum length is 1 character. 
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -932,7 +937,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, - encoding=encoding, emptyValue=emptyValue) + encoding=encoding, emptyValue=emptyValue, lineSep=lineSep) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 58ca7b83e5b2b..d92b0d5677e25 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -576,7 +576,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None, emptyValue=None, locale=None): + enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -675,6 +675,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -692,7 +695,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, - emptyValue=emptyValue, locale=locale) + emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 6bb50b42a369c..94bdb72d675d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -192,6 +192,20 @@ class CSVOptions( */ val emptyValueInWrite = emptyValue.getOrElse("\"\"") + /** + * A string between two consecutive JSON records. 
+ */ + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + require(sep.length == 1, "'lineSep' can contain only 1 character.") + sep + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(charset) + } + val lineSeparatorInWrite: Option[String] = lineSeparator + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -200,6 +214,8 @@ class CSVOptions( format.setQuoteEscape(escape) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + lineSeparatorInWrite.foreach(format.setLineSeparator) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) @@ -216,8 +232,10 @@ class CSVOptions( format.setDelimiter(delimiter) format.setQuote(quote) format.setQuoteEscape(escape) + lineSeparator.foreach(format.setLineSeparator) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) @@ -227,7 +245,10 @@ class CSVOptions( settings.setEmptyValue(emptyValueInRead) settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) - settings.setLineSeparatorDetectionEnabled(multiLine == true) + settings.setLineSeparatorDetectionEnabled(lineSeparatorInRead.isEmpty && multiLine) + lineSeparatorInRead.foreach { _ => + settings.setNormalizeLineEndingsWithinQuotes(!multiLine) + } settings } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f08fd64acd9a1..da88598eed061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -609,6 +609,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
<li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
* For instance, this is used while parsing dates and timestamps.</li>
+ * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing. Maximum length is 1 character.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 29d479f542115..5a807d3d4b93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -658,6 +658,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * whitespaces from values being written should be skipped. *
* <li>`ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not
* trailing whitespaces from values being written should be skipped.</li>
+ * <li>`lineSep` (default `\n`): defines the line separator that should be used for writing.
+ * Maximum length is 1 character.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 554baaf1a9b3b..b35b8851918b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -95,7 +95,7 @@ object TextInputCSVDataSource extends CSVDataSource { headerChecker: CSVHeaderChecker, requiredSchema: StructType): Iterator[InternalRow] = { val lines = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index e4250145a1ae2..c8e3e1c191044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -377,6 +377,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
<li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
* For instance, this is used while parsing dates and timestamps.</li>
+ * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing. Maximum length is 1 character.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e29cd2aa7c4e6..c275d63d32cc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File -import java.nio.charset.{Charset, UnsupportedCharsetException} +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} import java.nio.file.Files import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -33,7 +33,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.log4j.{AppenderSkeleton, LogManager} import org.apache.log4j.spi.LoggingEvent -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -1880,4 +1880,110 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("""Support line separator - default value \r, \r\n and \n""") { + val data = "\"a\",1\r\"c\",2\r\n\"d\",3\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("inferSchema", true).csv(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("_c0", StringType) :: StructField("_c1", IntegerType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + def testLineSeparator(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"Support line separator in ${encoding} #${id}") { + // Read + val data = + s""""a",1$lineSep + |c,2$lineSep" + |d",3""".stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(encoding)) + val schema = StructType(StructField("_c0", StringType) + :: StructField("_c1", LongType) :: Nil) + + val expected = Seq(("a", 1), ("\nc", 2), ("\nd", 3)) + .toDF("_c0", "_c1") + Seq(false, true).foreach { multiLine => + val reader = spark + .read + .option("lineSep", lineSep) + .option("multiLine", multiLine) + .option("encoding", encoding) + val df = if (inferSchema) { + reader.option("inferSchema", true).csv(path.getAbsolutePath) + } else { + reader.schema(schema).csv(path.getAbsolutePath) + } + checkAnswer(df, expected) + } + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), encoding) + assert( + readBack === s"a${lineSep}b${lineSep}c${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val readBack = spark + .read + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + 
// scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, ":", "ISO-8859-1", true), + (3, "!", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "у", "CP1251", true), + (8, "\r", "UTF-16LE", true), + (9, "\u000d", "UTF-32BE", false), + (10, "=", "US-ASCII", false), + (11, "$", "utf-32le", true) + ).foreach { case (testNum, sep, encoding, inferSchema) => + testLineSeparator(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("lineSep restrictions") { + val errMsg1 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg1.contains("'lineSep' cannot be an empty string")) + + val errMsg2 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "123").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg2.contains("'lineSep' can contain only 1 character")) + } } From ecb785f4e471ce3add66c67d0d8152dd237dbfaf Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 23 Nov 2018 21:08:06 +0100 Subject: [PATCH 111/145] [SPARK-26038] Decimal toScalaBigInt/toJavaBigInteger for decimals not fitting in long ## What changes were proposed in this pull request? Fix Decimal `toScalaBigInt` and `toJavaBigInteger` used to only work for decimals not fitting long. ## How was this patch tested? Added test to DecimalSuite. Closes #23022 from juliuszsompolski/SPARK-26038. Authored-by: Juliusz Sompolski Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/types/Decimal.scala | 16 ++++++++++++++-- .../apache/spark/sql/types/DecimalSuite.scala | 11 +++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index a3a844670e0c6..0192059a3a39f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -185,9 +185,21 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toScalaBigInt: BigInt = BigInt(toLong) + def toScalaBigInt: BigInt = { + if (decimalVal.ne(null)) { + decimalVal.toBigInt() + } else { + BigInt(toLong) + } + } - def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong) + def toJavaBigInteger: java.math.BigInteger = { + if (decimalVal.ne(null)) { + decimalVal.underlying().toBigInteger() + } else { + java.math.BigInteger.valueOf(toLong) + } + } def toUnscaledLong: Long = { if (decimalVal.ne(null)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 10de90c6a44ca..8abd7625c21aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -228,4 +228,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { val decimal = Decimal.apply(bigInt) assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") } + + test("SPARK-26038: toScalaBigInt/toJavaBigInteger") { + // not fitting long + val decimal = Decimal("1234568790123456789012348790.1234879012345678901234568790") + assert(decimal.toScalaBigInt == scala.math.BigInt("1234568790123456789012348790")) + assert(decimal.toJavaBigInteger == new 
java.math.BigInteger("1234568790123456789012348790")) + // fitting long + val decimalLong = Decimal(123456789123456789L, 18, 9) + assert(decimalLong.toScalaBigInt == scala.math.BigInt("123456789")) + assert(decimalLong.toJavaBigInteger == new java.math.BigInteger("123456789")) + } } From de84899204f3428f3d1d688b277dc06b021d860a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 23 Nov 2018 14:14:21 -0800 Subject: [PATCH 112/145] [SPARK-26140] Enable custom metrics implementation in shuffle reader ## What changes were proposed in this pull request? This patch defines an internal Spark interface for reporting shuffle metrics and uses that in shuffle reader. Before this patch, shuffle metrics is tied to a specific implementation (using a thread local temporary data structure and accumulators). After this patch, callers that define their own shuffle RDDs can create a custom metrics implementation. With this patch, we would be able to create a better metrics for the SQL layer, e.g. reporting shuffle metrics in the SQL UI, for each exchange operator. Note that I'm separating read side and write side implementations, as they are very different, to simplify code review. Write side change is at https://github.com/apache/spark/pull/23106 ## How was this patch tested? No behavior change expected, as it is a straightforward refactoring. Updated all existing test cases. Closes #23105 from rxin/SPARK-26140. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../spark/executor/ShuffleReadMetrics.scala | 18 ++++--- .../org/apache/spark/rdd/CoGroupedRDD.scala | 4 +- .../org/apache/spark/rdd/ShuffledRDD.scala | 4 +- .../org/apache/spark/rdd/SubtractedRDD.scala | 7 ++- .../shuffle/BlockStoreShuffleReader.scala | 5 +- .../apache/spark/shuffle/ShuffleManager.scala | 3 +- .../shuffle/ShuffleMetricsReporter.scala | 33 ++++++++++++ .../org/apache/spark/shuffle/metrics.scala | 52 +++++++++++++++++++ .../shuffle/sort/SortShuffleManager.scala | 6 ++- .../storage/ShuffleBlockFetcherIterator.scala | 10 ++-- .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++- .../spark/scheduler/CustomShuffledRDD.scala | 3 +- .../BlockStoreShuffleReaderSuite.scala | 5 +- .../ShuffleBlockFetcherIteratorSuite.scala | 31 +++++++---- .../spark/sql/execution/ShuffledRowRDD.scala | 4 +- 15 files changed, 155 insertions(+), 36 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/metrics.scala diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 4be395c8358b2..2f97e969d2dd2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.shuffle.ShuffleMetricsReporter import org.apache.spark.util.LongAccumulator @@ -123,12 +124,13 @@ class ShuffleReadMetrics private[spark] () extends Serializable { } } + /** * A temporary shuffle read metrics holder that is used to collect shuffle read metrics for each * shuffle dependency, and all temporary metrics will be merged into the [[ShuffleReadMetrics]] at * last. 
*/ -private[spark] class TempShuffleReadMetrics { +private[spark] class TempShuffleReadMetrics extends ShuffleMetricsReporter { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L @@ -137,13 +139,13 @@ private[spark] class TempShuffleReadMetrics { private[this] var _fetchWaitTime = 0L private[this] var _recordsRead = 0L - def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v - def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v - def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v - def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v - def incLocalBytesRead(v: Long): Unit = _localBytesRead += v - def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v - def incRecordsRead(v: Long): Unit = _recordsRead += v + override def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v + override def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v + override def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v + override def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v + override def incLocalBytesRead(v: Long): Unit = _localBytesRead += v + override def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v + override def incRecordsRead(v: Long): Unit = _recordsRead += v def remoteBlocksFetched: Long = _remoteBlocksFetched def localBlocksFetched: Long = _localBlocksFetched diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 4574c3724962e..7e76731f5e454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -143,8 +143,10 @@ class CoGroupedRDD[K: ClassTag]( case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle + val metrics = context.taskMetrics().createTempShuffleReadMetrics() val it = SparkEnv.get.shuffleManager - .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) + .getReader( + shuffleDependency.shuffleHandle, split.index, split.index + 1, context, metrics) .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index e8f9b27b7eb55..5ec99b7f4f3ab 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -101,7 +101,9 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] - SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + val metrics = context.taskMetrics().createTempShuffleReadMetrics() + SparkEnv.get.shuffleManager.getReader( + dep.shuffleHandle, split.index, split.index + 1, context, metrics) .read() .asInstanceOf[Iterator[(K, C)]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index a733eaa5d7e53..42d190377f104 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -107,9 +107,14 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( 
.asInstanceOf[Iterator[Product2[K, V]]].foreach(op) case shuffleDependency: ShuffleDependency[_, _, _] => + val metrics = context.taskMetrics().createTempShuffleReadMetrics() val iter = SparkEnv.get.shuffleManager .getReader( - shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + shuffleDependency.shuffleHandle, + partition.index, + partition.index + 1, + context, + metrics) .read() iter.foreach(op) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 74b0e0b3a741a..7cb031ce318b7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + readMetrics: ShuffleMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) @@ -53,7 +54,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), + readMetrics) val serializerInstance = dep.serializer.newInstance() @@ -66,7 +68,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( recordIter.map { record => readMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 4ea8a7120a9cc..d1061d83cb85a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -48,7 +48,8 @@ private[spark] trait ShuffleManager { handle: ShuffleHandle, startPartition: Int, endPartition: Int, - context: TaskContext): ShuffleReader[K, C] + context: TaskContext, + metrics: ShuffleMetricsReporter): ShuffleReader[K, C] /** * Remove a shuffle's metadata from the ShuffleManager. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala new file mode 100644 index 0000000000000..32865149c97c2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An interface for reporting shuffle information, for each shuffle. This interface assumes + * all the methods are called on a single-threaded, i.e. concrete implementations would not need + * to synchronize anything. + */ +private[spark] trait ShuffleMetricsReporter { + def incRemoteBlocksFetched(v: Long): Unit + def incLocalBlocksFetched(v: Long): Unit + def incRemoteBytesRead(v: Long): Unit + def incRemoteBytesReadToDisk(v: Long): Unit + def incLocalBytesRead(v: Long): Unit + def incFetchWaitTime(v: Long): Unit + def incRecordsRead(v: Long): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/metrics.scala b/core/src/main/scala/org/apache/spark/shuffle/metrics.scala new file mode 100644 index 0000000000000..33be677bc90cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/metrics.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An interface for reporting shuffle read metrics, for each shuffle. This interface assumes + * all the methods are called on a single-threaded, i.e. concrete implementations would not need + * to synchronize. + * + * All methods have additional Spark visibility modifier to allow public, concrete implementations + * that still have these methods marked as private[spark]. + */ +private[spark] trait ShuffleReadMetricsReporter { + private[spark] def incRemoteBlocksFetched(v: Long): Unit + private[spark] def incLocalBlocksFetched(v: Long): Unit + private[spark] def incRemoteBytesRead(v: Long): Unit + private[spark] def incRemoteBytesReadToDisk(v: Long): Unit + private[spark] def incLocalBytesRead(v: Long): Unit + private[spark] def incFetchWaitTime(v: Long): Unit + private[spark] def incRecordsRead(v: Long): Unit +} + + +/** + * An interface for reporting shuffle write metrics. This interface assumes all the methods are + * called on a single-threaded, i.e. concrete implementations would not need to synchronize. + * + * All methods have additional Spark visibility modifier to allow public, concrete implementations + * that still have these methods marked as private[spark]. 
+ */ +private[spark] trait ShuffleWriteMetricsReporter { + private[spark] def incBytesWritten(v: Long): Unit + private[spark] def incRecordsWritten(v: Long): Unit + private[spark] def incWriteTime(v: Long): Unit + private[spark] def decBytesWritten(v: Long): Unit + private[spark] def decRecordsWritten(v: Long): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0caf84c6050a8..57c3150e5a697 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -114,9 +114,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager handle: ShuffleHandle, startPartition: Int, endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { + context: TaskContext, + metrics: ShuffleMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, endPartition, context, metrics) } /** Get a writer for a given partition. Called on executors by map tasks. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index aecc2284a9588..a2e0713e70b04 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, ShuffleMetricsReporter} import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -51,7 +51,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. Note that zero-sized blocks are * already excluded, which happened in - * [[MapOutputTracker.convertMapStatuses]]. + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -59,6 +59,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param shuffleMetrics used to report shuffle metrics. 
*/ private[spark] final class ShuffleBlockFetcherIterator( @@ -71,7 +72,8 @@ final class ShuffleBlockFetcherIterator( maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean) + detectCorrupt: Boolean, + shuffleMetrics: ShuffleMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -137,8 +139,6 @@ final class ShuffleBlockFetcherIterator( */ private[this] val corruptedBlocks = mutable.HashSet[BlockId]() - private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index b917469e48747..419a26b857ea2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -397,8 +397,10 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } - val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + val taskContext = new TaskContextImpl( + 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala index 838686923767e..1be2e2a067115 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala @@ -104,8 +104,9 @@ class CustomShuffledRDD[K, V, C]( override def compute(p: Partition, context: TaskContext): Iterator[(K, C)] = { val part = p.asInstanceOf[CustomShuffledRDDPartition] + val metrics = context.taskMetrics().createTempShuffleReadMetrics() SparkEnv.get.shuffleManager.getReader( - dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context) + dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context, metrics) .read() .asInstanceOf[Iterator[(K, C)]] } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 2d8a83c6fabed..eb97d5a1e5074 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -126,11 +126,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext .set("spark.shuffle.compress", "false") .set("spark.shuffle.spill.compress", "false")) + val taskContext = TaskContext.empty() + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, - TaskContext.empty(), + taskContext, + metrics, serializerManager, blockManager, mapOutputTracker) diff --git 
a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b268195e09a5b..01ee9ef0825f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -102,8 +102,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ).toIterator + val taskContext = TaskContext.empty() + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val iterator = new ShuffleBlockFetcherIterator( - TaskContext.empty(), + taskContext, transfer, blockManager, blocksByAddress, @@ -112,7 +114,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + metrics) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -190,7 +193,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -258,7 +262,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -328,7 +333,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -392,7 +398,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. assert(Set(iterator.next()._1, iterator.next()._1) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) @@ -446,7 +453,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - false) + false, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -496,8 +504,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. 
+ val taskContext = TaskContext.empty() new ShuffleBlockFetcherIterator( - TaskContext.empty(), + taskContext, transfer, blockManager, blocksByAddress, @@ -506,7 +515,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, - detectCorrupt = true) + detectCorrupt = true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( @@ -552,7 +562,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // All blocks fetched return zero length and should trigger a receive-side error: val e = intercept[FetchFailedException] { iterator.next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 862ee05392f37..542266bc1ae07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -154,6 +154,7 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] + val metrics = context.taskMetrics().createTempShuffleReadMetrics() // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -161,7 +162,8 @@ class ShuffledRowRDD( dependency.shuffleHandle, shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, - context) + context, + metrics) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } From 7f5f7a967d36d78f73d8fa1e178dfdb324d73bf1 Mon Sep 17 00:00:00 2001 From: liuxian Date: Sat, 24 Nov 2018 09:10:15 -0600 Subject: [PATCH 113/145] [SPARK-25786][CORE] If ByteBuffer.hasArray is false, Kryo deserialization throws UnsupportedOperationException ## What changes were proposed in this pull request? In Kryo's `deserialize`, the input parameter is a `ByteBuffer`; if it is not backed by an accessible byte array, it throws an `UnsupportedOperationException`. Exception Info: ``` java.lang.UnsupportedOperationException was thrown. java.lang.UnsupportedOperationException at java.nio.ByteBuffer.array(ByteBuffer.java:994) at org.apache.spark.serializer.KryoSerializerInstance.deserialize(KryoSerializer.scala:362) ``` ## How was this patch tested? Added a unit test Closes #22779 from 10110346/InputStreamKryo.
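For context, a minimal standalone sketch (not part of the patch) of why a direct buffer hits this path: a `ByteBuffer` that is not array-backed reports `hasArray == false`, and calling `array()` on it throws, which is why the fix falls back to reading through a `ByteBufferInputStream`:

```scala
import java.nio.ByteBuffer

object DirectBufferSketch {
  def main(args: Array[String]): Unit = {
    // Heap buffers are backed by an accessible byte array.
    val heap = ByteBuffer.allocate(16)
    println(heap.hasArray)   // true -> input.setBuffer(bytes.array(), ...) is safe

    // Direct (off-heap) buffers are not array-backed.
    val direct = ByteBuffer.allocateDirect(16)
    println(direct.hasArray) // false
    direct.array()           // throws java.lang.UnsupportedOperationException
  }
}
```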
Authored-by: liuxian Signed-off-by: Sean Owen --- .../apache/spark/serializer/KryoSerializer.scala | 16 +++++++++++++--- .../spark/serializer/KryoSerializerSuite.scala | 12 ++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 66812a54846c6..1e1c27c477877 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -42,7 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, ByteBufferInputStream, SerializableConfiguration, SerializableJobConf, Utils} import org.apache.spark.util.collection.CompactBuffer /** @@ -417,7 +417,12 @@ private[spark] class KryoSerializerInstance( override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + if (bytes.hasArray) { + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + } else { + input.setBuffer(new Array[Byte](4096)) + input.setInputStream(new ByteBufferInputStream(bytes)) + } kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -429,7 +434,12 @@ private[spark] class KryoSerializerInstance( val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + if (bytes.hasArray) { + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + } else { + input.setBuffer(new Array[Byte](4096)) + input.setInputStream(new ByteBufferInputStream(bytes)) + } kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index e413fe3b774d0..a7eed4b6a8b88 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.ByteBuffer import java.util.concurrent.Executors import scala.collection.JavaConverters._ @@ -551,6 +552,17 @@ class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSpar deserializationStream.close() assert(serInstance.deserialize[Any](helloHello) === ((hello, hello))) } + + test("SPARK-25786: ByteBuffer.array -- UnsupportedOperationException") { + val serInstance = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] + val obj = "UnsupportedOperationException" + val serObj = serInstance.serialize(obj) + val byteBuffer = ByteBuffer.allocateDirect(serObj.array().length) + byteBuffer.put(serObj.array()) + byteBuffer.flip() + assert(serInstance.deserialize[Any](serObj) === (obj)) + assert(serInstance.deserialize[Any](byteBuffer) === (obj)) + } } class 
ClassLoaderTestingObject From 0f56977f8c9bfc48230d499925e31ff81bcd0f86 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 24 Nov 2018 09:12:05 -0600 Subject: [PATCH 114/145] [SPARK-26156][WEBUI] Revise summary section of stage page ## What changes were proposed in this pull request? In the summary section of the stage page: ![image](https://user-images.githubusercontent.com/1097932/48935518-ebef2b00-ef42-11e8-8672-eaa4cac92c5e.png) 1. The following metric names can be revised: Output => Output Size / Records Shuffle Read: => Shuffle Read Size / Records Shuffle Write => Shuffle Write Size / Records After the changes, the names are clearer and consistent with the other names on the same page. 2. The associated job id URL should not contain the 3 trailing spaces. Reduce the number of spaces to one, and exclude the space from the link. This is consistent with the SQL execution page. ## How was this patch tested? Manual check: ![image](https://user-images.githubusercontent.com/1097932/48935538-f7425680-ef42-11e8-8b2a-a4f388d3ea52.png) Closes #23125 from gengliangwang/reviseStagePage. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- .../org/apache/spark/ui/jobs/StagePage.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7e6cc4297d6b1..2b436b9234144 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -152,20 +152,20 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We }} {if (hasOutput(stageData)) {
<li> - <strong>Output: </strong> + <strong>Output Size / Records: </strong> {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"} </li> }} {if (hasShuffleRead(stageData)) { <li> - <strong>Shuffle Read: </strong> + <strong>Shuffle Read Size / Records: </strong> {s"${Utils.bytesToString(stageData.shuffleReadBytes)} / " + s"${stageData.shuffleReadRecords}"} </li> }} {if (hasShuffleWrite(stageData)) { <li> - <strong>Shuffle Write: </strong> + <strong>Shuffle Write Size / Records: </strong> {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + s"${stageData.shuffleWriteRecords}"} </li> @@ -183,10 +183,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We {if (!stageJobIds.isEmpty) {
  • Associated Job Ids: - {stageJobIds.map(jobId => {val detailUrl = "%s/jobs/job/?id=%s".format( - UIUtils.prependBaseUri(request, parent.basePath), jobId) - {s"${jobId}"}    - })} + {stageJobIds.sorted.map { jobId => + val jobURL = "%s/jobs/job/?id=%s" + .format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + {jobId.toString}  + }}
  • }} From eea4a0330b913cd45e369f09ec3d1dbb1b81f1b5 Mon Sep 17 00:00:00 2001 From: Lee moon soo Date: Sat, 24 Nov 2018 16:09:13 -0800 Subject: [PATCH 115/145] [MINOR][K8S] Invalid property "spark.driver.pod.name" is referenced in docs. ## What changes were proposed in this pull request? "Running on Kubernetes" references `spark.driver.pod.name` few places, and it should be `spark.kubernetes.driver.pod.name`. ## How was this patch tested? See changes Closes #23133 from Leemoonsoo/fix-driver-pod-name-prop. Authored-by: Lee moon soo Signed-off-by: Dongjoon Hyun --- docs/running-on-kubernetes.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index a9d448820e700..e940d9a63b7af 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -166,7 +166,7 @@ hostname via `spark.driver.host` and your spark driver's port to `spark.driver.p ### Client Mode Executor Pod Garbage Collection -If you run your Spark driver in a pod, it is highly recommended to set `spark.driver.pod.name` to the name of that pod. +If you run your Spark driver in a pod, it is highly recommended to set `spark.kubernetes.driver.pod.name` to the name of that pod. When this property is set, the Spark scheduler will deploy the executor pods with an [OwnerReference](https://kubernetes.io/docs/concepts/workloads/controllers/garbage-collection/), which in turn will ensure that once the driver pod is deleted from the cluster, all of the application's executor pods will also be deleted. @@ -175,7 +175,7 @@ an OwnerReference pointing to that pod will be added to each executor pod's Owne setting the OwnerReference to a pod that is not actually that driver pod, or else the executors may be terminated prematurely when the wrong pod is deleted. -If your application is not running inside a pod, or if `spark.driver.pod.name` is not set when your application is +If your application is not running inside a pod, or if `spark.kubernetes.driver.pod.name` is not set when your application is actually running in a pod, keep in mind that the executor pods may not be properly deleted from the cluster when the application exits. The Spark scheduler attempts to delete these pods, but if the network request to the API server fails for any reason, these pods will remain in the cluster. The executor processes should exit when they cannot reach the From 41d5aaec840234b1fcfd6f87f5e9e7729a3f0fe2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 26 Nov 2018 00:26:24 +0900 Subject: [PATCH 116/145] [SPARK-26148][PYTHON][TESTS] Increases default parallelism in PySpark tests to speed up ## What changes were proposed in this pull request? This PR proposes to increase parallelism in PySpark tests to speed up from 4 to 8. It decreases the elapsed time from https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99163/consoleFull Tests passed in 1770 seconds to https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99186/testReport/ Tests passed in 1027 seconds ## How was this patch tested? Jenkins tests Closes #23111 from HyukjinKwon/parallelism. 
Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 7ec73347d16bf..27f7527052e29 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -460,7 +460,7 @@ def parse_opts(): prog="run-tests" ) parser.add_option( - "-p", "--parallelism", type="int", default=4, + "-p", "--parallelism", type="int", default=8, help="The number of suites to test in parallel (default %default)" ) From c5daccb1dafca528ccb4be65d63c943bf9a7b0f2 Mon Sep 17 00:00:00 2001 From: Katrin Leinweber <9948149+katrinleinweber@users.noreply.github.com> Date: Sun, 25 Nov 2018 17:43:55 -0600 Subject: [PATCH 117/145] [MINOR] Update all DOI links to preferred resolver ## What changes were proposed in this pull request? The DOI foundation recommends [this new resolver](https://www.doi.org/doi_handbook/3_Resolution.html#3.8). Accordingly, this PR re`sed`s all static DOI links ;-) ## How was this patch tested? It wasn't, since it seems as safe as a "[typo fix](https://spark.apache.org/contributing.html)". In case any of the files is included from other projects, and should be updated there, please let me know. Closes #23129 from katrinleinweber/resolve-DOIs-securely. Authored-by: Katrin Leinweber <9948149+katrinleinweber@users.noreply.github.com> Signed-off-by: Sean Owen --- R/pkg/R/stats.R | 4 ++-- .../scala/org/apache/spark/api/java/JavaPairRDD.scala | 6 +++--- .../scala/org/apache/spark/api/java/JavaRDDLike.scala | 2 +- .../scala/org/apache/spark/rdd/PairRDDFunctions.scala | 8 ++++---- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- docs/ml-classification-regression.md | 4 ++-- docs/ml-collaborative-filtering.md | 4 ++-- docs/ml-frequent-pattern-mining.md | 8 ++++---- docs/mllib-collaborative-filtering.md | 4 ++-- docs/mllib-frequent-pattern-mining.md | 6 +++--- docs/mllib-isotonic-regression.md | 4 ++-- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 2 +- .../main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 4 ++-- .../scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 2 +- .../scala/org/apache/spark/ml/recommendation/ALS.scala | 2 +- .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 4 ++-- .../scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 2 +- .../spark/mllib/linalg/distributed/RowMatrix.scala | 2 +- .../org/apache/spark/mllib/recommendation/ALS.scala | 2 +- python/pyspark/ml/fpm.py | 6 +++--- python/pyspark/ml/recommendation.py | 2 +- python/pyspark/mllib/fpm.py | 2 +- python/pyspark/mllib/linalg/distributed.py | 2 +- python/pyspark/rdd.py | 2 +- python/pyspark/sql/dataframe.py | 4 ++-- .../spark/sql/catalyst/util/QuantileSummaries.scala | 2 +- .../org/apache/spark/sql/DataFrameStatFunctions.scala | 10 +++++----- .../spark/sql/execution/stat/FrequentItems.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 2 +- 29 files changed, 54 insertions(+), 54 deletions(-) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 497f18c763048..7252351ebebb2 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -109,7 +109,7 @@ setMethod("corr", #' #' Finding frequent items for columns, possibly with false positives. #' Using the frequent element count algorithm described in -#' \url{http://dx.doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. +#' \url{https://doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. #' #' @param x A SparkDataFrame. #' @param cols A vector column names to search frequent items in. 
@@ -143,7 +143,7 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' *exact* rank of x is close to (p * N). More precisely, #' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed -#' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 +#' optimizations). The algorithm was first present in [[https://doi.org/10.1145/375663.375670 #' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. #' Note that NA values will be ignored in numerical columns before calculation. For #' columns only containing NA values, an empty list is returned. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 80a4f84087466..50ed8d9bd3f68 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -952,7 +952,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -969,7 +969,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -985,7 +985,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 91ae1002abd21..5ba821935ac69 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -685,7 +685,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. 
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index e68c6b1366c7f..4bf4f082d0382 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -394,7 +394,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is * greater than `p`) would trigger sparse representation of registers, which may reduce the @@ -436,7 +436,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -456,7 +456,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -473,7 +473,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 743e3441eea55..6a25ee20b2c68 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1258,7 +1258,7 @@ abstract class RDD[T: ClassTag]( * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is greater * than `p`) would trigger sparse representation of registers, which may reduce the memory @@ -1290,7 +1290,7 @@ abstract class RDD[T: ClassTag]( * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index b3d109039da4d..42912a2e2bc31 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -941,9 +941,9 @@ Essentially isotonic regression is a best fitting the original data points. 
We implement a -[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111) which uses an approach to -[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 8b0f287dc39ad..58646642bfbcc 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -41,7 +41,7 @@ for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, clicks, purchases, likes, shares etc.). The approach used in `spark.ml` to deal with such data is taken -from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22). Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data as numbers representing the *strength* in observations of user actions (such as the number of clicks, or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of @@ -55,7 +55,7 @@ We scale the regularization parameter `regParam` in solving each least squares p the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper -"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)". It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md index c2043d495c149..f613664271ec6 100644 --- a/docs/ml-frequent-pattern-mining.md +++ b/docs/ml-frequent-pattern-mining.md @@ -18,7 +18,7 @@ for more information. ## FP-Growth The FP-growth algorithm is described in the paper -[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372), where "FP" stands for frequent pattern. Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, @@ -26,7 +26,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, -as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). 
+as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence is more scalable than a single-machine implementation. We refer users to the papers for more details. @@ -90,7 +90,7 @@ Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details. PrefixSpan is a sequential pattern mining algorithm described in [Pei et al., Mining Sequential Patterns by Pattern-Growth: The -PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. @@ -137,4 +137,4 @@ Refer to the [R API docs](api/R/spark.prefixSpan.html) for more details. {% include_example r/ml/prefixSpan.R %} - \ No newline at end of file + diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index b2300028e151b..aeebb26bb45f3 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -37,7 +37,7 @@ for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken -from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22). Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data as numbers representing the *strength* in observations of user actions (such as the number of clicks, or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of @@ -51,7 +51,7 @@ Since v1.1, we scale the regularization parameter `lambda` in solving each least the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper -"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)". It makes `lambda` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 0d3192c6b1d9c..8e4505756b275 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -15,7 +15,7 @@ a popular algorithm to mining frequent itemsets. ## FP-growth The FP-growth algorithm is described in the paper -[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372), where "FP" stands for frequent pattern. Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. 
Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, @@ -23,7 +23,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, -as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. @@ -122,7 +122,7 @@ Refer to the [`AssociationRules` Java docs](api/java/org/apache/spark/mllib/fpm/ PrefixSpan is a sequential pattern mining algorithm described in [Pei et al., Mining Sequential Patterns by Pattern-Growth: The -PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 99cab98c690c6..9964fce3273be 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -24,9 +24,9 @@ Essentially isotonic regression is a best fitting the original data points. `spark.mllib` supports a -[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111) which uses an approach to -[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 919496aa1a840..2eed84d51782a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -263,7 +263,7 @@ object KMeansModel extends MLReadable[KMeansModel] { /** * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. * - * @see Bahmani et al., Scalable k-means++. + * @see Bahmani et al., Scalable k-means++. */ @Since("1.5.0") class KMeans @Since("1.5.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 840a89b76d26b..7322815c12ab8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -118,10 +118,10 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { /** * :: Experimental :: * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in - * Li et al., PFP: Parallel FP-Growth for Query + * Li et al., PFP: Parallel FP-Growth for Query * Recommendation. 
PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in - * Han et al., Mining frequent patterns without + * Han et al., Mining frequent patterns without * candidate generation. Note null values in the itemsCol column are ignored during fit(). * * @see diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index bd1c1a8885201..2a3413553a6af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * A parallel PrefixSpan algorithm to mine frequent sequential patterns. * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth - * (see here). + * (see here). * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to * run the PrefixSpan algorithm. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index ffe592789b3cc..50ef4330ddc80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -557,7 +557,7 @@ object ALSModel extends MLReadable[ALSModel] { * * For implicit preference data, the algorithm used is based on * "Collaborative Filtering for Implicit Feedback Datasets", available at - * http://dx.doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here. + * https://doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here. * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 3a1bc35186dc3..519c1ea47c1db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -152,10 +152,10 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { /** * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in - * Li et al., PFP: Parallel FP-Growth for Query + * Li et al., PFP: Parallel FP-Growth for Query * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in - * Han et al., Mining frequent patterns without + * Han et al., Mining frequent patterns without * candidate generation. * * @param minSupport the minimal support level of the frequent pattern, any pattern that appears diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 64d6a0bc47b97..b2c09b408b40b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -45,7 +45,7 @@ import org.apache.spark.storage.StorageLevel * A parallel PrefixSpan algorithm to mine frequent sequential patterns. * The PrefixSpan algorithm is described in J. 
Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth - * (see here). + * (see here). * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 82ab716ed96a8..c12b751bfb8e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -540,7 +540,7 @@ class RowMatrix @Since("1.0.0") ( * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape. * Reference: * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce - * architectures" (see here) + * architectures" (see here) * * @param computeQ whether to computeQ * @return QRDecomposition(Q, R), Q = null if computeQ = false. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 14288221b6945..12870f819b147 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -54,7 +54,7 @@ case class Rating @Since("0.8.0") ( * * For implicit preference data, the algorithm used is based on * "Collaborative Filtering for Implicit Feedback Datasets", available at - * here, adapted for the blocked approach + * here, adapted for the blocked approach * used here. * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 886ad8409ca66..734763ebd3fa6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -167,8 +167,8 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, independent group of mining tasks. The FP-Growth algorithm is described in Han et al., Mining frequent patterns without candidate generation [HAN2000]_ - .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027 - .. [HAN2000] http://dx.doi.org/10.1145/335191.335372 + .. [LI2008] https://doi.org/10.1145/1454008.1454027 + .. [HAN2000] https://doi.org/10.1145/335191.335372 .. note:: null values in the feature column are ignored during fit(). .. note:: Internally `transform` `collects` and `broadcasts` association rules. @@ -254,7 +254,7 @@ class PrefixSpan(JavaParams): A parallel PrefixSpan algorithm to mine frequent sequential patterns. The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth - (see here). + (see here). This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` method to run the PrefixSpan algorithm. diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index a8eae9bd268d3..520d7912c1a10 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -57,7 +57,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha For implicit preference data, the algorithm used is based on `"Collaborative Filtering for Implicit Feedback Datasets", - `_, adapted for the blocked + `_, adapted for the blocked approach used here. 
Essentially instead of finding the low-rank approximations to the diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index de18dad1f675d..6accb9b4926e8 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -132,7 +132,7 @@ class PrefixSpan(object): A parallel PrefixSpan algorithm to mine frequent sequential patterns. The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth - ([[http://doi.org/10.1109/ICDE.2001.914830]]). + ([[https://doi.org/10.1109/ICDE.2001.914830]]). .. versionadded:: 1.6.0 """ diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 7e8b15056cabe..b7f09782be9dd 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -270,7 +270,7 @@ def tallSkinnyQR(self, computeQ=False): Reference: Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce architectures" - ([[http://dx.doi.org/10.1145/1996092.1996103]]) + ([[https://doi.org/10.1145/1996092.1996103]]) :param: computeQ: whether to computeQ :return: QRDecomposition(Q: RowMatrix, R: Matrix), where diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ccf39e1ffbe96..8bd6897df925f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2354,7 +2354,7 @@ def countApproxDistinct(self, relativeSD=0.05): The algorithm used is based on streamlib's implementation of `"HyperLogLog in Practice: Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available here - `_. + `_. :param relativeSD: Relative accuracy. Smaller values create counters that require more space. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c4f4d81999544..4abbeacfd56b4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1806,7 +1806,7 @@ def approxQuantile(self, col, probabilities, relativeError): This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first - present in [[http://dx.doi.org/10.1145/375663.375670 + present in [[https://doi.org/10.1145/375663.375670 Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. @@ -1928,7 +1928,7 @@ def freqItems(self, cols, support=None): """ Finding frequent items for columns, possibly with false positives. Using the frequent element count algorithm described in - "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". + "https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. .. note:: This function is meant for exploratory data analysis, as we make no diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 3190e511e2cb5..2a03f85ab594b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * Helper class to compute approximate quantile summary. 
* This implementation is based on the algorithm proposed in the paper: * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael - * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670) + * and Khanna, Sanjeev. (https://doi.org/10.1145/375663.375670) * * In order to optimize for speed, it maintains an internal buffer of the last seen samples, * and only inserts them after crossing a certain size threshold. This guarantees a near-constant diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b2f6a6ba83108..0b22b898557f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -51,7 +51,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * This method implements a variation of the Greenwald-Khanna algorithm (with some speed * optimizations). - * The algorithm was first present in + * The algorithm was first present in * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. * * @param col the name of the numerical column @@ -218,7 +218,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, + * here, proposed by Karp, * Schenker, and Papadimitriou. * The `support` should be greater than 1e-4. * @@ -265,7 +265,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, + * here, proposed by Karp, * Schenker, and Papadimitriou. * Uses a `default` support of 1%. * @@ -284,7 +284,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * * This function is meant for exploratory data analysis, as we make no guarantee about the @@ -328,7 +328,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * Uses a `default` support of 1%. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 86f6307254332..420faa6f24734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -69,7 +69,7 @@ object FrequentItems extends Logging { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * The `support` should be greater than 1e-4. * For Internal use only. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index bea652cc33076..ac25a8fd90bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -45,7 +45,7 @@ object StatFunctions extends Logging { * * This method implements a variation of the Greenwald-Khanna algorithm (with some speed * optimizations). - * The algorithm was first present in + * The algorithm was first present in * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. * * @param df the dataframe From 94145786a5b91a7f0bca44f27599a61c72f3a18f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 25 Nov 2018 15:53:07 -0800 Subject: [PATCH 118/145] [SPARK-25908][SQL][FOLLOW-UP] Add back unionAll ## What changes were proposed in this pull request? This PR is to add back `unionAll`, which is widely used. The name is also consistent with our ANSI SQL. We also have the corresponding `intersectAll` and `exceptAll`, which were introduced in Spark 2.4. ## How was this patch tested? Added a test case in DataFrameSuite Closes #23131 from gatorsmile/addBackUnionAll. Authored-by: gatorsmile Signed-off-by: gatorsmile --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 14 ++++++++++++++ R/pkg/R/generics.R | 3 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 1 + docs/sparkr.md | 2 +- docs/sql-migration-guide-upgrade.md | 2 ++ python/pyspark/sql/dataframe.py | 11 +++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 14 ++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 9 files changed, 53 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index de56061b4c1c7..cdeafdd90ce4a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -169,6 +169,7 @@ exportMethods("arrange", "toJSON", "transform", "union", + "unionAll", "unionByName", "unique", "unpersist", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 52e76570139e2..ad9cd845f696c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2732,6 +2732,20 @@ setMethod("union", dataFrame(unioned) }) +#' Return a new SparkDataFrame containing the union of rows +#' +#' This is an alias for `union`. 
+#' +#' @rdname union +#' @name unionAll +#' @aliases unionAll,SparkDataFrame,SparkDataFrame-method +#' @note unionAll since 1.4.0 +setMethod("unionAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + union(x, y) + }) + #' Return a new SparkDataFrame containing the union of rows, matched by column names #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index cbed276274ac1..b2ca6e62175e7 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -631,6 +631,9 @@ setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union setGeneric("union", function(x, y) { standardGeneric("union") }) +#' @rdname union +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + #' @rdname unionByName setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index f355a515935c8..77a29c9ecad86 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2458,6 +2458,7 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF expect_equal(count(unioned), 6) expect_equal(first(unioned)$name, "Michael") expect_equal(count(arrange(suppressWarnings(union(df, df2)), df$age)), 6) + expect_equal(count(arrange(suppressWarnings(unionAll(df, df2)), df$age)), 6) df1 <- select(df2, "age", "name") unioned1 <- arrange(unionByName(df1, df), df1$age) diff --git a/docs/sparkr.md b/docs/sparkr.md index acd0e77c4d71a..5972435a0e409 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -718,4 +718,4 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 3.0.0 - The deprecated methods `sparkR.init`, `sparkRSQL.init`, `sparkRHive.init` have been removed. Use `sparkR.session` instead. - - The deprecated methods `parquetFile`, `saveAsParquetFile`, `jsonFile`, `registerTempTable`, `createExternalTable`, `dropTempTable`, `unionAll` have been removed. Use `read.parquet`, `write.parquet`, `read.json`, `createOrReplaceTempView`, `createTable`, `dropTempView`, `union` instead. + - The deprecated methods `parquetFile`, `saveAsParquetFile`, `jsonFile`, `registerTempTable`, `createExternalTable`, and `dropTempTable` have been removed. Use `read.parquet`, `write.parquet`, `read.json`, `createOrReplaceTempView`, `createTable`, `dropTempView`, `union` instead. diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 397ca59d96497..68cb8f5a0d18c 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -9,6 +9,8 @@ displayTitle: Spark SQL Upgrading Guide ## Upgrading From Spark SQL 2.4 to 3.0 + - Since Spark 3.0, the Dataset and DataFrame API `unionAll` is not deprecated any more. It is an alias for `union`. + - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 3.0, the builder comes to not update the configurations. This is the same behavior as Java/Scala API in 2.3 and above. If you want to update them, you need to update them prior to creating a `SparkSession`. 
- In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4abbeacfd56b4..ca15b36699166 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1448,6 +1448,17 @@ def union(self, other): """ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) + @since(1.3) + def unionAll(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this and another frame. + + This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union + (that does deduplication of elements), use this function followed by :func:`distinct`. + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return self.union(other) + @since(2.3) def unionByName(self, other): """ Returns a new :class:`DataFrame` containing union of rows in this and another frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e757921b485df..f361bde281732 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1852,6 +1852,20 @@ class Dataset[T] private[sql]( CombineUnions(Union(logicalPlan, other.logicalPlan)) } + /** + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. + * This is an alias for `union`. + * + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.0.0 + */ + def unionAll(other: Dataset[T]): Dataset[T] = union(other) + /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0ee2627814ba0..7a0767a883f15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -97,6 +97,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { unionDF.agg(avg('key), max('key), min('key), sum('key)), Row(50.5, 100, 1, 25250) :: Nil ) + + // unionAll is an alias of union + val unionAllDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + checkAnswer(unionDF, unionAllDF) } test("union should union DataFrames with UDTs (SPARK-13410)") { From 6339c8c2c6b80a85e4ad6a7fa7595cf567a1113e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 26 Nov 2018 11:13:28 +0800 Subject: [PATCH 119/145] [SPARK-24762][SQL] Enable Option of Product encoders ## What changes were proposed in this pull request? SparkSQL doesn't support to encode `Option[Product]` as a top-level row now, because in SparkSQL entire top-level row can't be null. However for use cases like Aggregator, it is reasonable to use `Option[Product]` as buffer and output column types. Due to above limitation, we don't do it for now. 
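As an aside before the proposal itself, a small hedged sketch (spark-shell style, with `spark.implicits._` in scope; `MyClass` is a hypothetical case class) of what the old limitation looked like in practice, together with the `Tuple1` workaround the old error message suggested:

```scala
case class MyClass(a: Int, b: String)   // hypothetical example class

// Before this patch, a top-level Dataset of Option-of-Product was rejected at
// encoder creation with UnsupportedOperationException:
// val bad = Seq(Some(MyClass(1, "a")), None).toDS()

// Workaround suggested by the old error message: wrap the value in Tuple1 so
// the top-level row itself is never null.
val ds = Seq(Tuple1(MyClass(1, "a")), Tuple1(null)).toDS()
```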
This patch proposes to encode `Option[Product]` at top-level as single struct column. So we can work around the issue that entire top-level row can't be null. To summarize encoding of `Product` and `Option[Product]`. For `Product`, 1. at root level, the schema is all fields are flatten it into multiple columns. The `Product ` can't be null, otherwise it throws an exception. ```scala val df = Seq((1 -> "a"), (2 -> "b")).toDF() df.printSchema() root |-- _1: integer (nullable = false) |-- _2: string (nullable = true) ``` 2. At non-root level, `Product` is a struct type column. ```scala val df = Seq((1, (1 -> "a")), (2, (2 -> "b")), (3, null)).toDF() df.printSchema() root |-- _1: integer (nullable = false) |-- _2: struct (nullable = true) | |-- _1: integer (nullable = false) | |-- _2: string (nullable = true) ``` For `Option[Product]`, 1. it was not supported at root level. After this change, it is a struct type column. ```scala val df = Seq(Some(1 -> "a"), Some(2 -> "b"), None).toDF() df.printSchema root |-- value: struct (nullable = true) | |-- _1: integer (nullable = false) | |-- _2: string (nullable = true) ``` 2. At non-root level, it is also a struct type column. ```scala val df = Seq((1, Some(1 -> "a")), (2, Some(2 -> "b")), (3, None)).toDF() df.printSchema root |-- _1: integer (nullable = false) |-- _2: struct (nullable = true) | |-- _1: integer (nullable = false) | |-- _2: string (nullable = true) ``` 3. For use case like Aggregator, it was not supported too. After this change, we support to use `Option[Product]` as buffer/output column type. ```scala val df = Seq( OptionBooleanIntData("bob", Some((true, 1))), OptionBooleanIntData("bob", Some((false, 2))), OptionBooleanIntData("bob", None)).toDF() val group = df .groupBy("name") .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) group.printSchema root |-- name: string (nullable = true) |-- isGood: struct (nullable = true) | |-- _1: boolean (nullable = false) | |-- _2: integer (nullable = false) ``` The buffer and output type of `OptionBooleanIntAggregator` is both `Option[(Boolean, Int)`. ## How was this patch tested? Added test. Closes #21732 from viirya/SPARK-24762. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../catalyst/encoders/ExpressionEncoder.scala | 32 +++++--- .../scala/org/apache/spark/sql/Dataset.scala | 10 +-- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../aggregate/TypedAggregateExpression.scala | 18 ++--- .../spark/sql/DatasetAggregatorSuite.scala | 64 ++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 77 ++++++++++++++++--- 6 files changed, 163 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 592520c59a761..d019924711e3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -49,15 +49,6 @@ object ExpressionEncoder { val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - if (ScalaReflection.optionOfProductType(tpe)) { - throw new UnsupportedOperationException( - "Cannot create encoder for Option of Product type, because Product type is represented " + - "as a row, and the entire row can not be null in Spark SQL like normal databases. 
" + - "You can wrap your type with Tuple1 if you do want top level null Product objects, " + - "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " + - "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") - } - val cls = mirror.runtimeClass(tpe) val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) @@ -198,7 +189,7 @@ case class ExpressionEncoder[T]( val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -213,6 +204,9 @@ case class ExpressionEncoder[T]( } else { // For other input objects like primitive, array, map, etc., we construct a struct to wrap // the serializer which is a column of an row. + // + // Note: Because Spark SQL doesn't allow top-level row to be null, to encode + // top-level Option[Product] type, we make it as a top-level struct column. CreateNamedStruct(Literal("value") :: objSerializer :: Nil) } }.flatten @@ -226,7 +220,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. @@ -253,10 +247,24 @@ case class ExpressionEncoder[T]( }) /** - * Returns true if the type `T` is serialized as a struct. + * Returns true if the type `T` is serialized as a struct by `objSerializer`. */ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] + /** + * Returns true if the type `T` is an `Option` type. + */ + def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + + /** + * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in + * the struct are naturally mapped to top-level columns in a row. In other words, the serialized + * struct is flattened to row. But in case of the `T` is also an `Option` type, it can't be + * flattened to top-level row, because in Spark SQL top-level row can't be null. This method + * returns true if `T` is serialized as struct and is not `Option` type. + */ + def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType + // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f361bde281732..b10d66dfb1aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1093,7 +1093,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. 
val left = { - val combined = if (!this.exprEnc.isSerializedAsStruct) { + val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1103,7 +1103,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (!other.exprEnc.isSerializedAsStruct) { + val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1116,14 +1116,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.isSerializedAsStruct) { + if (!this.exprEnc.isSerializedAsStructForTopLevel) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.isSerializedAsStruct) { + if (!other.exprEnc.isSerializedAsStructForTopLevel) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1396,7 +1396,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (!encoder.isSerializedAsStruct) { + if (!encoder.isSerializedAsStructForTopLevel) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 2d849c65997a7..a3cbea9021f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -458,7 +458,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (!kExprEnc.isSerializedAsStruct) { + val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) { assert(groupingAttributes.length == 1) if (SQLConf.get.nameNonStructGroupingKeyAsValue) { groupingAttributes.head diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 39200ec00e152..b75752945a492 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -40,9 +40,9 @@ object TypedAggregateExpression { val outputEncoder = encoderFor[OUT] val outputType = outputEncoder.objSerializer.dataType - // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer - // expression is an alias of `BoundReference`, which means the buffer object doesn't need - // serialization. + // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct + // and the serializer expression is an alias of `BoundReference`, which means the buffer + // object doesn't need serialization. 
val isSimpleBuffer = { bufferSerializer.head match { case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true @@ -76,7 +76,7 @@ object TypedAggregateExpression { None, bufferSerializer, bufferEncoder.resolveAndBind().deserializer, - outputEncoder.serializer, + outputEncoder.objSerializer, outputType, outputEncoder.objSerializer.nullable) } @@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression( inputSchema: Option[StructType], bufferSerializer: Seq[NamedExpression], bufferDeserializer: Expression, - outputSerializer: Seq[Expression], + outputSerializer: Expression, dataType: DataType, nullable: Boolean, mutableAggBufferOffset: Int = 0, @@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression( aggregator.merge(buffer, input) } - private lazy val resultObjToRow = dataType match { - case _: StructType => - UnsafeProjection.create(CreateStruct(outputSerializer)) - case _ => - assert(outputSerializer.length == 1) - UnsafeProjection.create(outputSerializer.head) - } + private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer) override def eval(buffer: Any): Any = { val resultObj = aggregator.finish(buffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 538ea3c66c40e..97c3f358c0e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType} object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { @@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { case class OptionBooleanData(name: String, isGood: Option[Boolean]) +case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)]) case class OptionBooleanAggregator(colName: String) extends Aggregator[Row, Option[Boolean], Option[Boolean]] { @@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String) def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() } +case class OptionBooleanIntAggregator(colName: String) + extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] { + + override def zero: Option[(Boolean, Int)] = None + + override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[(Boolean, Int)] + } else { + val nestedRow = row.getStruct(index) + Some((nestedRow.getBoolean(0), nestedRow.getInt(1))) + } + merge(buffer, value) + } + + override def merge( + b1: Option[(Boolean, Int)], + b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = { + if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) { + val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0) + Some((true, newInt)) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction + + override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + override def outputEncoder: 
Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } + + test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { + val df = Seq( + OptionBooleanIntData("bob", Some((true, 1))), + OptionBooleanIntData("bob", Some((false, 2))), + OptionBooleanIntData("bob", None)).toDF() + + val group = df + .groupBy("name") + .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) + + val expectedSchema = new StructType() + .add("name", StringType, nullable = true) + .add("isGood", + new StructType() + .add("_1", BooleanType, nullable = false) + .add("_2", IntegerType, nullable = false), + nullable = true) + + assert(df.schema == expectedSchema) + assert(group.schema == expectedSchema) + checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) + checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index baece2ddac7eb..0f900833d2cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1312,15 +1312,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18251: the type of Dataset can't be Option of Product type") { - checkDataset(Seq(Some(1), None).toDS(), Some(1), None) - - val e = intercept[UnsupportedOperationException] { - Seq(Some(1 -> "a"), None).toDS() - } - assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) - } - test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") { // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt // instead of Int for avoiding possible overflow. 
@@ -1558,6 +1549,74 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(Row("Amsterdam"))) } + test("SPARK-24762: Enable top-level Option of Product encoders") { + val data = Seq(Some((1, "a")), Some((2, "b")), None) + val ds = data.toDS() + + checkDataset( + ds, + data: _*) + + val schema = new StructType().add( + "value", + new StructType() + .add("_1", IntegerType, nullable = false) + .add("_2", StringType, nullable = true), + nullable = true) + + assert(ds.schema == schema) + + val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) + val nestedDs = nestedOptData.toDS() + + checkDataset( + nestedDs, + nestedOptData: _*) + + val nestedSchema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true)))), + StructField("_2", DoubleType, nullable = false) + )), nullable = true) + )) + assert(nestedDs.schema == nestedSchema) + } + + test("SPARK-24762: Resolving Option[Product] field") { + val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS() + .as[(Int, Option[(String, Double)])] + checkDataset(ds, + (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None)) + } + + test("SPARK-24762: select Option[Product] field") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) + checkDataset(ds1, + Some((1, 2)), Some((2, 3)), Some((3, 4))) + + val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]]) + checkDataset(ds2, + None, None, Some((3, 4))) + } + + test("SPARK-24762: joinWith on Option[Product]") { + val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a") + val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b") + val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") + checkDataset(joined, (Some((2, 3)), Some((1, 2)))) + } + + test("SPARK-24762: typed agg on Option[Product] type") { + val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS() + assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1))) + + assert(ds.groupByKey(x => x).count().collect() === + Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1))) + } + test("SPARK-25942: typed aggregation on primitive type") { val ds = Seq(1, 2, 3).toDS() From 6ab8485da21035778920da0d9332709f9acaff45 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 26 Nov 2018 15:47:04 +0800 Subject: [PATCH 120/145] [SPARK-26169] Create DataFrameSetOperationsSuite ## What changes were proposed in this pull request? Create a new suite DataFrameSetOperationsSuite for the test cases of DataFrame/Dataset's set operations. Also, add test cases of NULL handling for Array Except and Array Intersect. ## How was this patch tested? N/A Closes #23137 from gatorsmile/setOpsTest. 
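As a quick illustration of the NULL handling the new test cases pin down (a DataFrame-level sketch, assuming a `SparkSession` named `spark` with `spark.implicits._` imported; the suite added below exercises the same semantics at the expression level):

```scala
import org.apache.spark.sql.functions.{array_except, array_intersect}
import spark.implicits._

// a = [1, 2, null], b = [1, null, null]
val df = Seq((Seq(Some(1), Some(2), None), Seq(Some(1), None, None))).toDF("a", "b")

// array_except drops a left-side null when the right side also contains null -> [2]
// array_intersect keeps a single null when both sides contain null           -> [1, null]
df.select(array_except($"a", $"b"), array_intersect($"a", $"b")).show()
```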
Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../CollectionExpressionsSuite.scala | 26 + .../sql/DataFrameSetOperationsSuite.scala | 509 ++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 478 ---------------- 3 files changed, 535 insertions(+), 478 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1415b7da6fca1..d2edb2f24688d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1658,6 +1658,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + test("Array Except - null handling") { + val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val oneNull = Literal.create(Seq(null), ArrayType(IntegerType)) + val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType)) + + checkEvaluation(ArrayExcept(oneNull, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(twoNulls, twoNulls), Seq.empty) + checkEvaluation(ArrayExcept(twoNulls, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(empty, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(oneNull, empty), Seq(null)) + checkEvaluation(ArrayExcept(twoNulls, empty), Seq(null)) + } + test("Array Intersect") { val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) @@ -1769,4 +1782,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("Array Intersect - null handling") { + val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val oneNull = Literal.create(Seq(null), ArrayType(IntegerType)) + val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType)) + + checkEvaluation(ArrayIntersect(oneNull, oneNull), Seq(null)) + checkEvaluation(ArrayIntersect(twoNulls, twoNulls), Seq(null)) + checkEvaluation(ArrayIntersect(twoNulls, oneNull), Seq(null)) + checkEvaluation(ArrayIntersect(oneNull, twoNulls), Seq(null)) + checkEvaluation(ArrayIntersect(empty, oneNull), Seq.empty) + checkEvaluation(ArrayIntersect(oneNull, empty), Seq.empty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala new file mode 100644 index 0000000000000..30452af1fad64 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} +import org.apache.spark.sql.test.SQLTestData.NullStrings +import org.apache.spark.sql.types._ + +class DataFrameSetOperationsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("except") { + checkAnswer( + lowerCaseData.except(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.except(lowerCaseData), Nil) + checkAnswer(upperCaseData.except(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.except(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.except(nullInts), + Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.except(allNulls.filter("0 = 1")), + Row(null) :: Nil) + checkAnswer( + allNulls.except(allNulls), + Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.except(df.filter("0 = 1")), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").except(allNulls), + Nil) + } + + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + + test("except distinct - SQL compliance") { + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 3).toDF("id") + + checkAnswer( + df_left.except(df_right), + Row(2) :: Row(4) :: Nil + ) + } + + test("except - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.except(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.except(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.except(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.except(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate 
values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. + val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersect") { + checkAnswer( + lowerCaseData.intersect(lowerCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + } + + test("intersect - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersect(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersect(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersect(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersect(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. 
+ checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.union(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } + + test("SPARK-17123: Performing set operations that combine non-scala native types") { + val dates = Seq( + (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), + (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) + ).toDF("date", "timestamp", "decimal") + + val widenTypedRows = Seq( + (new Timestamp(2), 10.5D, "string") + ).toDF("date", "timestamp", "decimal") + + dates.union(widenTypedRows).collect() + dates.except(widenTypedRows).collect() + dates.intersect(widenTypedRows).collect() + } + + test("SPARK-19893: cannot run set operations with map type") { + val df = spark.range(1).select(map(lit("key"), $"id").as("m")) + val e = intercept[AnalysisException](df.intersect(df)) + assert(e.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e2 = intercept[AnalysisException](df.except(df)) + assert(e2.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e3 = intercept[AnalysisException](df.distinct()) + assert(e3.message.contains( + "Cannot 
have map type columns in DataFrame which calls set operations")) + withTempView("v") { + df.createOrReplaceTempView("v") + val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) + assert(e4.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + } + } + + test("union all") { + val unionDF = testData.union(testData).union(testData) + .union(testData).union(testData) + + // Before optimizer, Union should be combined. + assert(unionDF.queryExecution.analyzed.collect { + case j: Union if j.children.size == 5 => j }.size === 1) + + checkAnswer( + unionDF.agg(avg('key), max('key), min('key), sum('key)), + Row(50.5, 100, 1, 25250) :: Nil + ) + + // unionAll is an alias of union + val unionAllDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + checkAnswer(unionDF, unionAllDF) + } + + test("union should union DataFrames with UDTs (SPARK-13410)") { + val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) + val schema1 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) + val schema2 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val df1 = spark.createDataFrame(rowRDD1, schema1) + val df2 = spark.createDataFrame(rowRDD2, schema2) + + checkAnswer( + df1.union(df2).orderBy("label"), + Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) + ) + } + + test("union by name") { + var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") + val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") + val unionDf = df1.unionByName(df2.unionByName(df3)) + checkAnswer(unionDf, + Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil + ) + + // Check if adjacent unions are combined into a single one + assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) + + // Check failure cases + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains( + "Union can only be performed on tables with the same number of columns, " + + "but the first table has 2 columns and the second table has 3 columns")) + + df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + df2 = Seq((4, 5, 6)).toDF("a", "c", "d") + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) + } + + test("union by name - type coercion") { + var df1 = Seq((1, "a")).toDF("c0", "c1") + var df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((1, 1.0)).toDF("c0", "c1") + df2 = Seq((8L, 3.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } + + test("union by name - check case sensitivity") 
{ + def checkCaseSensitiveTest(): Unit = { + val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") + val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") + checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg2 = intercept[AnalysisException] { + checkCaseSensitiveTest() + }.getMessage + assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkCaseSensitiveTest() + } + } + + test("union by name - check name duplication") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var df1 = Seq((1, 1)).toDF(c0, c1) + var df2 = Seq((1, 1)).toDF("c0", "c1") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) + df1 = Seq((1, 1)).toDF("c0", "c1") + df2 = Seq((1, 1)).toDF(c0, c1) + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) + } + } + } + + test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { + def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + val df1 = spark.createDataFrame(Seq( + (1, 1) + )).toDF("a", "b").withColumn("c", newCol) + + val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) + checkAnswer(df2, result) + } + + check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) + check(lit(null).cast("int"), $"c".isNotNull, Seq()) + check(lit(2).cast("int"), $"c".isNull, Seq()) + check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" =!= 2, Seq()) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7a0767a883f15..fc3faa08d55f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -85,129 +85,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } - test("union all") { - val unionDF = testData.union(testData).union(testData) - .union(testData).union(testData) - - // Before optimizer, Union should be combined. 
- assert(unionDF.queryExecution.analyzed.collect { - case j: Union if j.children.size == 5 => j }.size === 1) - - checkAnswer( - unionDF.agg(avg('key), max('key), min('key), sum('key)), - Row(50.5, 100, 1, 25250) :: Nil - ) - - // unionAll is an alias of union - val unionAllDF = testData.unionAll(testData).unionAll(testData) - .unionAll(testData).unionAll(testData) - - checkAnswer(unionDF, unionAllDF) - } - - test("union should union DataFrames with UDTs (SPARK-13410)") { - val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) - val schema1 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) - val schema2 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val df1 = spark.createDataFrame(rowRDD1, schema1) - val df2 = spark.createDataFrame(rowRDD2, schema2) - - checkAnswer( - df1.union(df2).orderBy("label"), - Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) - ) - } - - test("union by name") { - var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") - val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") - val unionDf = df1.unionByName(df2.unionByName(df3)) - checkAnswer(unionDf, - Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil - ) - - // Check if adjacent unions are combined into a single one - assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) - - // Check failure cases - df1 = Seq((1, 2)).toDF("a", "c") - df2 = Seq((3, 4, 5)).toDF("a", "b", "c") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains( - "Union can only be performed on tables with the same number of columns, " + - "but the first table has 2 columns and the second table has 3 columns")) - - df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - df2 = Seq((4, 5, 6)).toDF("a", "c", "d") - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) - } - - test("union by name - type coercion") { - var df1 = Seq((1, "a")).toDF("c0", "c1") - var df2 = Seq((3, 1L)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) - - df1 = Seq((1, 1.0)).toDF("c0", "c1") - df2 = Seq((8L, 3.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) - - df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") - df2 = Seq(("a", 4.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) - - df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") - df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") - val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") - checkAnswer(df1.unionByName(df2.unionByName(df3)), - Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil - ) - } - - test("union by name - check case sensitivity") { - def checkCaseSensitiveTest(): Unit = { - val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") - val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") - checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - val errMsg2 = intercept[AnalysisException] { - checkCaseSensitiveTest() - }.getMessage - assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) - } - 
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkCaseSensitiveTest() - } - } - - test("union by name - check name duplication") { - Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - var df1 = Seq((1, 1)).toDF(c0, c1) - var df2 = Seq((1, 1)).toDF("c0", "c1") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) - df1 = Seq((1, 1)).toDF("c0", "c1") - df2 = Seq((1, 1)).toDF(c0, c1) - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) - } - } - } - test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) @@ -528,259 +405,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("except") { - checkAnswer( - lowerCaseData.except(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.except(lowerCaseData), Nil) - checkAnswer(upperCaseData.except(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.except(nullInts.filter("0 = 1")), - nullInts) - checkAnswer( - nullInts.except(nullInts), - Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.except(allNulls.filter("0 = 1")), - Row(null) :: Nil) - checkAnswer( - allNulls.except(allNulls), - Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.except(df.filter("0 = 1")), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").except(allNulls), - Nil) - } - - test("SPARK-23274: except between two projects without references used in filter") { - val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") - val df1 = df.filter($"a" === 1) - val df2 = df.filter($"a" === 2) - checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) - checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) - } - - test("except distinct - SQL compliance") { - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 3).toDF("id") - - checkAnswer( - df_left.except(df_right), - Row(2) :: Row(4) :: Nil - ) - } - - test("except - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.except(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.except(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.except(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.except(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("except all") { - checkAnswer( - lowerCaseData.exceptAll(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) - checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.exceptAll(nullInts.filter("0 = 1")), - 
nullInts) - checkAnswer( - nullInts.exceptAll(nullInts), - Nil) - - // check that duplicate values are preserved - checkAnswer( - allNulls.exceptAll(allNulls.filter("0 = 1")), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - checkAnswer( - allNulls.exceptAll(allNulls.limit(2)), - Row(null) :: Row(null) :: Nil) - - // check that duplicates are retained. - val df = spark.sparkContext.parallelize( - NullStrings(1, "id1") :: - NullStrings(1, "id1") :: - NullStrings(2, "id1") :: - NullStrings(3, null) :: Nil).toDF("id", "value") - - checkAnswer( - df.exceptAll(df.filter("0 = 1")), - Row(1, "id1") :: - Row(1, "id1") :: - Row(2, "id1") :: - Row(3, null) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").exceptAll(allNulls), - Nil) - - } - - test("exceptAll - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.exceptAll(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.exceptAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.exceptAll(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.exceptAll(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersect") { - checkAnswer( - lowerCaseData.intersect(lowerCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersect(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.intersect(allNulls), - Row(null) :: Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.intersect(df), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - } - - test("intersect - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersect(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersect(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersect(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersect(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersectAll") { - checkAnswer( - lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), - Row(1, "a") :: - Row(2, "b") :: - Row(2, "b") :: - Row(3, "c") :: - Row(3, "c") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersectAll(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // Duplicate nulls are preserved. 
- checkAnswer( - allNulls.intersectAll(allNulls), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 2, 2, 3).toDF("id") - - checkAnswer( - df_left.intersectAll(df_right), - Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) - } - - test("intersectAll - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersectAll(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersectAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersectAll(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersectAll(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) @@ -1782,56 +1406,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val df1 = (1 to 100).map(Tuple1.apply).toDF("i") - val df2 = (1 to 30).map(Tuple1.apply).toDF("i") - val intersect = df1.intersect(df2) - val except = df1.except(df2) - assert(intersect.count() === 30) - assert(except.count() === 70) - } - - test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { - val df1 = (1 to 20).map(Tuple1.apply).toDF("i") - val df2 = (1 to 10).map(Tuple1.apply).toDF("i") - - // When generating expected results at here, we need to follow the implementation of - // Rand expression. 
- def expected(df: DataFrame): Seq[Row] = { - df.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.filter(_.getInt(0) < rng.nextDouble() * 10) - } - } - - val union = df1.union(df2) - checkAnswer( - union.filter('i < rand(7) * 10), - expected(union) - ) - checkAnswer( - union.select(rand(7)), - union.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.map(_ => rng.nextDouble()).map(i => Row(i)) - } - ) - - val intersect = df1.intersect(df2) - checkAnswer( - intersect.filter('i < rand(7) * 10), - expected(intersect) - ) - - val except = df1.except(df2) - checkAnswer( - except.filter('i < rand(7) * 10), - expected(except) - ) - } - test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") @@ -2280,21 +1854,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-17123: Performing set operations that combine non-scala native types") { - val dates = Seq( - (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), - (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) - ).toDF("date", "timestamp", "decimal") - - val widenTypedRows = Seq( - (new Timestamp(2), 10.5D, "string") - ).toDF("date", "timestamp", "decimal") - - dates.union(widenTypedRows).collect() - dates.except(widenTypedRows).collect() - dates.intersect(widenTypedRows).collect() - } - test("SPARK-18070 binary operator should not consider nullability when comparing input types") { val rows = Seq(Row(Seq(1), Seq(1))) val schema = new StructType() @@ -2314,25 +1873,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(BigDecimal(0)) :: Nil) } - test("SPARK-19893: cannot run set operations with map type") { - val df = spark.range(1).select(map(lit("key"), $"id").as("m")) - val e = intercept[AnalysisException](df.intersect(df)) - assert(e.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e2 = intercept[AnalysisException](df.except(df)) - assert(e2.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e3 = intercept[AnalysisException](df.distinct()) - assert(e3.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - withTempView("v") { - df.createOrReplaceTempView("v") - val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) - assert(e4.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - } - } - test("SPARK-20359: catalyst outer join optimization should not throw npe") { val df1 = Seq("a", "b", "c").toDF("x") .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" 
}.apply($"x")) @@ -2517,24 +2057,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { - def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { - val df1 = spark.createDataFrame(Seq( - (1, 1) - )).toDF("a", "b").withColumn("c", newCol) - - val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) - checkAnswer(df2, result) - } - - check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) - check(lit(null).cast("int"), $"c".isNotNull, Seq()) - check(lit(2).cast("int"), $"c".isNull, Seq()) - check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" =!= 2, Seq()) - } - test("SPARK-25402 Null handling in BooleanSimplification") { val schema = StructType.fromDDL("a boolean, b int") val rows = Seq(Row(null, 1)) From 6bb60b30fd74b2c38640a4e54e5bb19eb890793e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 26 Nov 2018 15:51:28 +0800 Subject: [PATCH 121/145] [SPARK-26168][SQL] Update the code comments in Expression and Aggregate ## What changes were proposed in this pull request? This PR is to improve the code comments to document some common traits and traps about the expression. ## How was this patch tested? N/A Closes #23135 from gatorsmile/addcomments. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/TypeCoercion.scala | 5 ++- .../sql/catalyst/expressions/Expression.scala | 44 +++++++++++++++---- .../expressions/namedExpressions.scala | 3 ++ .../plans/logical/basicLogicalOperators.scala | 16 ++++++- 4 files changed, 56 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 72ac80e0a0a18..133fa119b7aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -181,8 +181,9 @@ object TypeCoercion { } /** - * The method finds a common type for data types that differ only in nullable, containsNull - * and valueContainsNull flags. If the input types are too different, None is returned. + * The method finds a common type for data types that differ only in nullable flags, including + * `nullable`, `containsNull` of [[ArrayType]] and `valueContainsNull` of [[MapType]]. + * If the input types are different besides nullable flags, None is returned. 
*/ def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { if (t1 == t2) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d51b11024a09d..2ecec61adb0ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf @@ -40,12 +41,28 @@ import org.apache.spark.sql.types._ * "name(arguments...)", the concrete implementation must be a case class whose constructor * arguments are all Expressions types. See [[Substring]] for an example. * - * There are a few important traits: + * There are a few important traits or abstract classes: * * - [[Nondeterministic]]: an expression that is not deterministic. + * - [[Stateful]]: an expression that contains mutable state. For example, MonotonicallyIncreasingID + * and Rand. A stateful expression is always non-deterministic. * - [[Unevaluable]]: an expression that is not supposed to be evaluated. * - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to * interpreted mode. + * - [[NullIntolerant]]: an expression that is null intolerant (i.e. any null input will result in + * null output). + * - [[NonSQLExpression]]: a common base trait for the expressions that do not have SQL + * expressions like representation. For example, `ScalaUDF`, `ScalaUDAF`, + * and object `MapObjects` and `Invoke`. + * - [[UserDefinedExpression]]: a common base trait for user-defined functions, including + * UDF/UDAF/UDTF. + * - [[HigherOrderFunction]]: a common base trait for higher order functions that take one or more + * (lambda) functions and applies these to some objects. The function + * produces a number of variables which can be consumed by some lambda + * functions. + * - [[NamedExpression]]: An [[Expression]] that is named. + * - [[TimeZoneAwareExpression]]: A common base trait for time zone aware expressions. + * - [[SubqueryExpression]]: A base interface for expressions that contain a [[LogicalPlan]]. * * - [[LeafExpression]]: an expression that has no child. * - [[UnaryExpression]]: an expression that has one child. @@ -54,12 +71,20 @@ import org.apache.spark.sql.types._ * - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have * the same output data type. * + * A few important traits used for type coercion rules: + * - [[ExpectsInputTypes]]: an expression that has the expected input types. This trait is typically + * used by operator expressions (e.g. [[Add]], [[Subtract]]) to define + * expected input types without any implicit casting. + * - [[ImplicitCastInputTypes]]: an expression that has the expected input types, which can be + * implicitly castable using [[TypeCoercion.ImplicitTypeCasts]]. 
+ * - [[ComplexTypeMergingExpression]]: to resolve output types of the complex expressions + * (e.g., [[CaseWhen]]). */ abstract class Expression extends TreeNode[Expression] { /** * Returns true when an expression is a candidate for static evaluation before the query is - * executed. + * executed. A typical use case: [[org.apache.spark.sql.catalyst.optimizer.ConstantFolding]] * * The following conditions are used to determine suitability for constant folding: * - A [[Coalesce]] is foldable if all of its children are foldable @@ -72,7 +97,8 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns true when the current expression always return the same result for fixed inputs from - * children. + * children. The non-deterministic expressions should not change in number and order. They should + * not be evaluated during the query planning. * * Note that this means that an expression should be considered as non-deterministic if: * - it relies on some mutable internal state, or @@ -252,8 +278,9 @@ abstract class Expression extends TreeNode[Expression] { /** - * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization - * time (e.g. Star). This trait is used by those expressions. + * An expression that cannot be evaluated. These expressions don't live past analysis or + * optimization time (e.g. Star) and should not be evaluated during query planning and + * execution. */ trait Unevaluable extends Expression { @@ -724,9 +751,10 @@ abstract class TernaryExpression extends Expression { } /** - * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type. - * This logic is usually utilized by expressions combining data from multiple child expressions - * of non-primitive types (e.g. [[CaseWhen]]). + * A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]] + * and `valueContainsNull` of [[MapType]], containsNull, valueContainsNull flags of the output date + * type. This is usually utilized by the expressions (e.g. [[CaseWhen]]) that combine data from + * multiple child expressions of non-primitive types. */ trait ComplexTypeMergingExpression extends Expression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 049ea77691395..02b48f9e30f2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -130,6 +130,9 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. * + * Note that when creating a new Alias, all the [[AttributeReference]] that refer to + * the original alias should be updated to the new one. + * * @param child The computation being performed * @param name The name to be associated with the result of computing [[child]]. 
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 07fa17b233a47..a26ec4eed8648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{AliasIdentifier} +import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString @@ -575,6 +575,18 @@ case class Range( } } +/** + * This is a Group by operator with the aggregate functions and projections. + * + * @param groupingExpressions expressions for grouping keys + * @param aggregateExpressions expressions for a project list, which could contain + * [[AggregateFunction]]s. + * + * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before + * separating projection from grouping and aggregate, we should avoid expression-level optimization + * on aggregateExpressions, which could reference an expression in groupingExpressions. + * For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]] + */ case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], From 1bb60ab8392adf8b896cc04fb1d060620cf09d8a Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Mon, 26 Nov 2018 05:57:33 -0600 Subject: [PATCH 122/145] [SPARK-26153][ML] GBT & RandomForest avoid unnecessary `first` job to compute `numFeatures` ## What changes were proposed in this pull request? use base models' `numFeature` instead of `first` job ## How was this patch tested? existing tests Closes #23123 from zhengruifeng/avoid_first_job. 
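Purely as an illustrative sketch (not part of this patch: the local session, toy dataset, and printed value below are made up for demonstration), the user-facing API is unchanged; the patch only derives `numFeatures` from the fitted base models instead of running an extra `first()` job on the input:

```scala
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Hypothetical local session and toy data, for demonstration only.
val spark = SparkSession.builder().master("local[2]").appName("numFeatures-sketch").getOrCreate()
import spark.implicits._

val train = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(0.9, 0.1))
).toDF("label", "features")

val model = new RandomForestClassifier().setNumTrees(3).fit(train)

// After this patch, this value comes from the trained trees
// (model.trees.head.numFeatures) rather than from train.first().features.size.
println(model.numFeatures) // 2
```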
Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 5 +++-- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 6 ++++-- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index fab8155add5a8..09a9df6d15ece 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -180,7 +180,6 @@ class GBTClassifier @Since("1.4.0") ( (convert2LabeledPoint(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -196,7 +195,6 @@ class GBTClassifier @Since("1.4.0") ( maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, validationIndicatorCol) - instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = if (withValidation) { @@ -206,6 +204,9 @@ class GBTClassifier @Since("1.4.0") ( GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 05fff8885fbf2..0a3bfd1f85e08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -142,7 +142,7 @@ class RandomForestClassifier @Since("1.4.0") ( .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 186fa2399af05..9b386ef5eed8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -165,7 +165,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { (extractLabeledPoints(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) instr.logPipelineStage(this) @@ -173,7 +172,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, 
checkpointInterval, featureSubsetStrategy) - instr.logNumFeatures(numFeatures) val (baseLearners, learnerWeights) = if (withValidation) { GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, @@ -182,6 +180,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f5e668ca71db..afa9a646412b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -133,7 +133,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeRegressionModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) new RandomForestRegressionModel(uid, trees, numFeatures) } From 2512a1d42911370854ca42d987c851128fa0b263 Mon Sep 17 00:00:00 2001 From: Anastasios Zouzias Date: Mon, 26 Nov 2018 11:10:38 -0600 Subject: [PATCH 123/145] [SPARK-26121][STRUCTURED STREAMING] Allow users to define prefix of Kafka's consumer group (group.id) ## What changes were proposed in this pull request? Allow the Spark Structured Streaming user to specify the prefix of the consumer group (group.id), compared to force consumer group ids of the form `spark-kafka-source-*` ## How was this patch tested? Unit tests provided by Spark (backwards compatible change, i.e., user can optionally use the functionality) `mvn test -pl external/kafka-0-10` Closes #23103 from zouzias/SPARK-26121. Authored-by: Anastasios Zouzias Signed-off-by: cody koeninger --- .../structured-streaming-kafka-integration.md | 37 ++++++++++++------- .../sql/kafka010/KafkaSourceProvider.scala | 18 +++++++-- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 71fd5b10cc407..a549ce2a6a05f 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -123,7 +123,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") -### Creating a Kafka Source for Batch Queries +### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, you can create a Dataset/DataFrame for a defined range of offsets. @@ -374,17 +374,24 @@ The following configurations are optional: streaming and batch Rate limit on maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume. + + groupIdPrefix + string + spark-kafka-source + streaming and batch + Prefix of consumer group identifiers (`group.id`) that are generated by structured streaming queries + ## Writing Data to Kafka -Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. 
Take note that +Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that Apache Kafka only supports at least once write semantics. Consequently, when writing---either Streaming Queries or Batch Queries---to Kafka, some records may be duplicated; this can happen, for example, if Kafka needs to retry a message that was not acknowledged by a Broker, even though that Broker received and wrote the message record. -Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, +Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, if writing the query is successful, then you can assume that the query output was written at least once. A possible -solution to remove duplicates when reading the written data could be to introduce a primary (unique) key +solution to remove duplicates when reading the written data could be to introduce a primary (unique) key that can be used to perform de-duplication when reading. The Dataframe being written to Kafka should have the following columns in schema: @@ -405,8 +412,8 @@ The Dataframe being written to Kafka should have the following columns in schema \* The topic column is required if the "topic" configuration option is not specified.
    -The value column is the only required option. If a key column is not specified then -a ```null``` valued key column will be automatically added (see Kafka semantics on +The value column is the only required option. If a key column is not specified then +a ```null``` valued key column will be automatically added (see Kafka semantics on how ```null``` valued key values are handled). If a topic column exists then its value is used as the topic when writing the given row to Kafka, unless the "topic" configuration option is set i.e., the "topic" configuration option overrides the topic column. @@ -568,7 +575,7 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ .save() - + {% endhighlight %} @@ -576,23 +583,25 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ ## Kafka Specific Configurations -Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, -`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see +Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs) for parameters related to reading data, and [Kafka producer config docs](http://kafka.apache.org/documentation/#producerconfigs) for parameters related to writing data. Note that the following Kafka params cannot be set and the Kafka source or sink will throw an exception: -- **group.id**: Kafka source will create a unique group id for each query automatically. +- **group.id**: Kafka source will create a unique group id for each query automatically. The user can +set the prefix of the automatically generated group.id's via the optional source option `groupIdPrefix`, default value +is "spark-kafka-source". - **auto.offset.reset**: Set the source option `startingOffsets` to specify - where to start instead. Structured Streaming manages which offsets are consumed internally, rather - than rely on the kafka Consumer to do it. This will ensure that no data is missed when new + where to start instead. Structured Streaming manages which offsets are consumed internally, rather + than rely on the kafka Consumer to do it. This will ensure that no data is missed when new topics/partitions are dynamically subscribed. Note that `startingOffsets` only applies when a new streaming query is started, and that resuming will always pick up from where the query left off. -- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use +- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the keys. -- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. +- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the values. - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. 
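For illustration only (this sketch is not part of the patch; the bootstrap servers, topic name, and prefix are placeholder values), a streaming query that opts into the new option looks roughly like this:

```scala
import org.apache.spark.sql.SparkSession

// Placeholder session; in spark-shell a `spark` session already exists.
val spark = SparkSession.builder().appName("groupIdPrefix-sketch").getOrCreate()

val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  // Option added by this patch; when omitted it defaults to "spark-kafka-source".
  .option("groupIdPrefix", "my-pipeline")
  .load()
```

With the option set, the automatically generated consumer group ids take the form `my-pipeline-<UUID>-<metadata path hash>` instead of `spark-kafka-source-<UUID>-<metadata path hash>`.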
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 5034bd73d6e74..f770f0c2a04c2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -77,7 +77,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -119,7 +119,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -159,7 +159,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -538,6 +538,18 @@ private[kafka010] object KafkaSourceProvider extends Logging { .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) .build() + /** + * Returns a unique consumer group (group.id), allowing the user to set the prefix of + * the consumer group + */ + private def streamingUniqueGroupId( + parameters: Map[String, String], + metadataPath: String): String = { + val groupIdPrefix = parameters + .getOrElse("groupIdPrefix", "spark-kafka-source") + s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}" + } + /** Class to conveniently update Kafka config params, while logging the changes */ private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { private val map = new ju.HashMap[String, Object](kafkaParams.asJava) From 3df307aa515b3564686e75d1b71754bbcaaf2dec Mon Sep 17 00:00:00 2001 From: Nihar Sheth Date: Mon, 26 Nov 2018 11:06:02 -0800 Subject: [PATCH 124/145] [SPARK-25960][K8S] Support subpath mounting with Kubernetes ## What changes were proposed in this pull request? This PR adds configurations to use subpaths with Spark on k8s. 
Subpaths (https://kubernetes.io/docs/concepts/storage/volumes/#using-subpath) allow the user to specify a path within a volume to use instead of the volume's root. ## How was this patch tested? Added unit tests. Ran SparkPi on a cluster with event logging pointed at a subpath-mount and verified the driver host created and used the subpath. Closes #23026 from NiharS/k8s_subpath. Authored-by: Nihar Sheth Signed-off-by: Marcelo Vanzin --- docs/running-on-kubernetes.md | 17 ++++ .../org/apache/spark/deploy/k8s/Config.scala | 1 + .../deploy/k8s/KubernetesVolumeSpec.scala | 1 + .../deploy/k8s/KubernetesVolumeUtils.scala | 2 + .../features/MountVolumesFeatureStep.scala | 1 + .../k8s/KubernetesVolumeUtilsSuite.scala | 12 +++ .../MountVolumesFeatureStepSuite.scala | 79 +++++++++++++++++++ .../submit/KubernetesDriverBuilderSuite.scala | 34 ++++++++ .../k8s/KubernetesExecutorBuilderSuite.scala | 1 + 9 files changed, 148 insertions(+) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e940d9a63b7af..2c01e1e7155ef 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -245,6 +245,7 @@ To mount a volume of any of the types above into the driver pod, use the followi ``` --conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path= --conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly= +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.subPath= ``` Specifically, `VolumeType` can be one of the following values: `hostPath`, `emptyDir`, and `persistentVolumeClaim`. `VolumeName` is the name you want to use for the volume under the `volumes` field in the pod specification. @@ -806,6 +807,14 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.subPath + (none) + + Specifies a subpath to be mounted from the volume into the driver pod. + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint. + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly (none) @@ -830,6 +839,14 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.subPath + (none) + + Specifies a subpath to be mounted from the volume into the executor pod. + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint. 
+ + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly false diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index a32bd93bb65bc..724acd231a6cb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -297,6 +297,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY = "mount.subPath" val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index b1762d1efe2ea..1a214fad96618 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -34,5 +34,6 @@ private[spark] case class KubernetesEmptyDirVolumeConf( private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( volumeName: String, mountPath: String, + mountSubPath: String, mountReadOnly: Boolean, volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index 713df5fffc3a2..155326469235b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -39,6 +39,7 @@ private[spark] object KubernetesVolumeUtils { getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" for { path <- properties.getTry(pathKey) @@ -46,6 +47,7 @@ private[spark] object KubernetesVolumeUtils { } yield KubernetesVolumeSpec( volumeName = volumeName, mountPath = path, + mountSubPath = properties.get(subPathKey).getOrElse(""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), volumeConf = volumeConf ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index e60259c4a9b5a..1473a7d3ee7f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -51,6 +51,7 @@ private[spark] class MountVolumesFeatureStep( val volumeMount = new 
VolumeMountBuilder() .withMountPath(spec.mountPath) .withReadOnly(spec.mountReadOnly) + .withSubPath(spec.mountSubPath) .withName(spec.volumeName) .build() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index d795d159773a8..de79a58a3a756 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -33,6 +33,18 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { KubernetesHostPathVolumeConf("/hostPath")) } + test("Parses subPath correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.mount.subPath", "subPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountSubPath === "subPath") + } + test("Parses persistentVolumeClaim volumes correctly") { val sparkConf = new SparkConf(false) sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 2a957460ca8e0..aadbf16897f46 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -43,6 +43,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/hostPath/tmp") ) @@ -62,6 +63,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -83,6 +85,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) ) @@ -104,6 +107,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesEmptyDirVolumeConf(None, None) ) @@ -125,12 +129,14 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val hpVolumeConf = KubernetesVolumeSpec( "hpVolume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/hostPath/tmp") ) val pvcVolumeConf = KubernetesVolumeSpec( "checkpointVolume", "/checkpoints", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -142,4 +148,77 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(configuredPod.pod.getSpec.getVolumes.size() === 2) assert(configuredPod.container.getVolumeMounts.size() === 2) } + + test("Mounts subpath on emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "foo", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = 
emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDirMount = configuredPod.container.getVolumeMounts.get(0) + assert(emptyDirMount.getMountPath === "/tmp") + assert(emptyDirMount.getName === "testVolume") + assert(emptyDirMount.getSubPath === "foo") + } + + test("Mounts subpath on persistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "bar", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + val pvcMount = configuredPod.container.getVolumeMounts.get(0) + assert(pvcMount.getMountPath === "/tmp") + assert(pvcMount.getName === "testVolume") + assert(pvcMount.getSubPath === "bar") + } + + test("Mounts multiple subpaths") { + val volumeConf = KubernetesEmptyDirVolumeConf(None, None) + val emptyDirSpec = KubernetesVolumeSpec( + "testEmptyDir", + "/tmp/foo", + "foo", + true, + KubernetesEmptyDirVolumeConf(None, None) + ) + val pvcSpec = KubernetesVolumeSpec( + "testPVC", + "/tmp/bar", + "bar", + true, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy( + roleVolumes = emptyDirSpec :: pvcSpec :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + val mounts = configuredPod.container.getVolumeMounts + assert(mounts.size() === 2) + assert(mounts.get(0).getName === "testEmptyDir") + assert(mounts.get(0).getMountPath === "/tmp/foo") + assert(mounts.get(0).getSubPath === "foo") + assert(mounts.get(1).getName === "testPVC") + assert(mounts.get(1).getMountPath === "/tmp/bar") + assert(mounts.get(1).getSubPath === "bar") + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index fe900fda6e545..3708864592d75 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -140,6 +140,40 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val volumeSpec = KubernetesVolumeSpec( "volume", "/tmp", + "", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + JavaMainAppResource(None), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + hadoopConfSpec = None) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + DRIVER_CMD_STEP_TYPE) 
+ } + + test("Apply volumes step if a mount subpath is present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + "foo", false, KubernetesHostPathVolumeConf("/path")) val conf = KubernetesConf( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 1fea08c37ccc6..a59f6d072023e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -107,6 +107,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val volumeSpec = KubernetesVolumeSpec( "volume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/checkpoint")) val conf = KubernetesConf( From 76ef02e499db49c0c6a37fa9dff3d731aeac9898 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Mon, 26 Nov 2018 14:08:32 -0600 Subject: [PATCH 125/145] [SPARK-21809] Change Stage Page to use datatables to support sorting columns and searching Support column sort, pagination and search for Stage Page using jQuery DataTable and REST API. Before this commit, the Stage page generated a hard-coded HTML table that could not support search. Supporting search and sort (over all applications rather than the 20 entries in the current page) in any case will greatly improve the user experience. Created the stagespage-template.html for displaying application information in datables. Added REST api endpoint and javascript code to fetch data from the endpoint and display it on the data table. Because of the above change, certain functionalities in the page had to be modified to support the addition of datatables. For example, the toggle checkbox 'Select All' previously would add the checked fields as columns in the Task table and as rows in the Summary Metrics table, but after the change, only columns are added in the Task Table as it got tricky to add rows dynamically in the datatables. ## How was this patch tested? I have attached the screenshots of the Stage Page UI before and after the fix. **Before:** 30564304-35991e1c-9c8a-11e7-850f-2ac7a347f600 31360592-cbaa2bae-ad14-11e7-941d-95b4c7d14970 **After:** 31360591-c5650ee4-ad14-11e7-9665-5a08d8f21830 31360604-d266b6b0-ad14-11e7-94b5-dcc4bb5443f4 Closes #21688 from pgandhi999/SPARK-21809-2.3. 
Authored-by: pgandhi Signed-off-by: Thomas Graves --- .../ui/static/executorspage-template.html | 8 +- .../apache/spark/ui/static/executorspage.js | 84 +- .../spark/ui/static/images/sort_asc.png | Bin 0 -> 160 bytes .../ui/static/images/sort_asc_disabled.png | Bin 0 -> 148 bytes .../spark/ui/static/images/sort_both.png | Bin 0 -> 201 bytes .../spark/ui/static/images/sort_desc.png | Bin 0 -> 158 bytes .../ui/static/images/sort_desc_disabled.png | Bin 0 -> 146 bytes .../org/apache/spark/ui/static/stagepage.js | 958 ++++++++++++ .../spark/ui/static/stagespage-template.html | 124 ++ .../org/apache/spark/ui/static/utils.js | 113 +- .../spark/ui/static/webui-dataTables.css | 20 + .../org/apache/spark/ui/static/webui.css | 101 ++ .../apache/spark/status/AppStatusStore.scala | 26 +- .../spark/status/api/v1/StagesResource.scala | 121 +- .../org/apache/spark/status/api/v1/api.scala | 5 +- .../org/apache/spark/status/storeTypes.scala | 5 +- .../scala/org/apache/spark/ui/UIUtils.scala | 2 + .../apache/spark/ui/jobs/ExecutorTable.scala | 149 -- .../org/apache/spark/ui/jobs/StagePage.scala | 325 +---- .../blacklisting_for_stage_expectation.json | 1287 +++++++++-------- ...acklisting_node_for_stage_expectation.json | 112 +- .../one_stage_attempt_json_expectation.json | 40 +- .../one_stage_json_expectation.json | 40 +- .../stage_task_list_expectation.json | 100 +- ...multi_attempt_app_json_1__expectation.json | 40 +- ...multi_attempt_app_json_2__expectation.json | 40 +- ...k_list_w__offset___length_expectation.json | 250 +++- ...stage_task_list_w__sortBy_expectation.json | 100 +- ...tBy_short_names___runtime_expectation.json | 100 +- ...rtBy_short_names__runtime_expectation.json | 100 +- ...age_with_accumulable_json_expectation.json | 150 +- .../spark/status/AppStatusUtilsSuite.scala | 10 +- .../org/apache/spark/ui/StagePageSuite.scala | 12 - 33 files changed, 3064 insertions(+), 1358 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_asc.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_asc_disabled.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_both.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_desc.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_desc_disabled.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/stagepage.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/stagespage-template.html create mode 100644 core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css delete mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 5c91304e49fd7..f2c17aef097a4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -16,10 +16,10 @@ --> diff --git a/core/src/main/resources/org/apache/spark/ui/static/utils.js b/core/src/main/resources/org/apache/spark/ui/static/utils.js index 4f63f6413d6de..deeafad4eb5f5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/utils.js +++ b/core/src/main/resources/org/apache/spark/ui/static/utils.js @@ -18,7 +18,7 @@ // this function works exactly the same as UIUtils.formatDuration function 
formatDuration(milliseconds) { if (milliseconds < 100) { - return milliseconds + " ms"; + return parseInt(milliseconds).toFixed(1) + " ms"; } var seconds = milliseconds * 1.0 / 1000; if (seconds < 1) { @@ -74,3 +74,114 @@ function getTimeZone() { return new Date().toString().match(/\((.*)\)/)[1]; } } + +function formatLogsCells(execLogs, type) { + if (type !== 'display') return Object.keys(execLogs); + if (!execLogs) return; + var result = ''; + $.each(execLogs, function (logName, logUrl) { + result += '' + }); + return result; +} + +function getStandAloneAppId(cb) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var appId = words[ind + 1]; + cb(appId); + return; + } + ind = words.indexOf("history"); + if (ind > 0) { + var appId = words[ind + 1]; + cb(appId); + return; + } + // Looks like Web UI is running in standalone mode + // Let's get application-id using REST End Point + $.getJSON(location.origin + "/api/v1/applications", function(response, status, jqXHR) { + if (response && response.length > 0) { + var appId = response[0].id; + cb(appId); + return; + } + }); +} + +// This function is a helper function for sorting in datatable. +// When the data is in duration (e.g. 12ms 2s 2min 2h ) +// It will convert the string into integer for correct ordering +function ConvertDurationString(data) { + data = data.toString(); + var units = data.replace(/[\d\.]/g, '' ) + .replace(' ', '') + .toLowerCase(); + var multiplier = 1; + + switch(units) { + case 's': + multiplier = 1000; + break; + case 'min': + multiplier = 600000; + break; + case 'h': + multiplier = 3600000; + break; + default: + break; + } + return parseFloat(data) * multiplier; +} + +function createTemplateURI(appId, templateName) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var baseURI = words.slice(0, ind + 1).join('/') + '/' + appId + '/static/' + templateName + '-template.html'; + return baseURI; + } + ind = words.indexOf("history"); + if(ind > 0) { + var baseURI = words.slice(0, ind).join('/') + '/static/' + templateName + '-template.html'; + return baseURI; + } + return location.origin + "/static/" + templateName + "-template.html"; +} + +function setDataTableDefaults() { + $.extend($.fn.dataTable.defaults, { + stateSave: true, + lengthMenu: [[20, 40, 60, 100, -1], [20, 40, 60, 100, "All"]], + pageLength: 20 + }); +} + +function formatDate(date) { + if (date <= 0) return "-"; + else return date.split(".")[0].replace("T", " "); +} + +function createRESTEndPointForExecutorsPage(appId) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var appId = words[ind + 1]; + var newBaseURI = words.slice(0, ind + 2).join('/'); + return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors" + } + ind = words.indexOf("history"); + if (ind > 0) { + var appId = words[ind + 1]; + var attemptId = words[ind + 2]; + var newBaseURI = words.slice(0, ind).join('/'); + if (isNaN(attemptId)) { + return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors"; + } else { + return newBaseURI + "/api/v1/applications/" + appId + "/" + attemptId + "/allexecutors"; + } + } + return location.origin + "/api/v1/applications/" + appId + "/allexecutors"; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css b/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css new file mode 100644 index 0000000000000..f6b4abed21e0d --- /dev/null +++ 
b/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +table.dataTable thead .sorting_asc { background: url('images/sort_asc.png') no-repeat bottom right; } + +table.dataTable thead .sorting_desc { background: url('images/sort_desc.png') no-repeat bottom right; } \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 266eeec55576e..fe5bb25687af1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -260,4 +260,105 @@ a.expandbutton { .paginate_button.active > a { color: #999999; text-decoration: underline; +} + +.title-table { + clear: left; + display: inline-block; +} + +.table-dataTable { + width: 100%; +} + +.container-fluid-div { + width: 200px; +} + +.scheduler-delay-checkbox-div { + width: 120px; +} + +.task-deserialization-time-checkbox-div { + width: 175px; +} + +.shuffle-read-blocked-time-checkbox-div { + width: 187px; +} + +.shuffle-remote-reads-checkbox-div { + width: 157px; +} + +.result-serialization-time-checkbox-div { + width: 171px; +} + +.getting-result-time-checkbox-div { + width: 141px; +} + +.peak-execution-memory-checkbox-div { + width: 170px; +} + +#active-tasks-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#active-tasks-table th:first-child { + border-left: 1px solid #dddddd; +} + +#accumulator-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#accumulator-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-executor-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-executor-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-metrics-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-metrics-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-execs-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-execs-table th:first-child { + border-left: 1px solid #dddddd; +} + +#active-executors-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#active-executors-table th:first-child { + border-left: 1px solid #dddddd; } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala 
b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 63b9d8988499d..5c0ed4d5d8f4c 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -351,7 +351,9 @@ private[spark] class AppStatusStore( def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.toApi).toSeq.reverse + .max(maxTasks).asScala.map { taskDataWrapper => + constructTaskData(taskDataWrapper) + }.toSeq.reverse } def taskList( @@ -390,7 +392,9 @@ private[spark] class AppStatusStore( } val ordered = if (ascending) indexed else indexed.reverse() - ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + ordered.skip(offset).max(length).asScala.map { taskDataWrapper => + constructTaskData(taskDataWrapper) + }.toSeq } def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { @@ -496,6 +500,24 @@ private[spark] class AppStatusStore( store.close() } + def constructTaskData(taskDataWrapper: TaskDataWrapper) : v1.TaskData = { + val taskDataOld: v1.TaskData = taskDataWrapper.toApi + val executorLogs: Option[Map[String, String]] = try { + Some(executorSummary(taskDataOld.executorId).executorLogs) + } catch { + case e: NoSuchElementException => e.getMessage + None + } + new v1.TaskData(taskDataOld.taskId, taskDataOld.index, + taskDataOld.attempt, taskDataOld.launchTime, taskDataOld.resultFetchStart, + taskDataOld.duration, taskDataOld.executorId, taskDataOld.host, taskDataOld.status, + taskDataOld.taskLocality, taskDataOld.speculative, taskDataOld.accumulatorUpdates, + taskDataOld.errorMessage, taskDataOld.taskMetrics, + executorLogs.getOrElse(Map[String, String]()), + AppStatusUtils.schedulerDelay(taskDataOld), + AppStatusUtils.gettingResultTime(taskDataOld)) + } + } private[spark] object AppStatusStore { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 30d52b97833e6..f81892734c2de 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -16,15 +16,16 @@ */ package org.apache.spark.status.api.v1 -import java.util.{List => JList} +import java.util.{HashMap, List => JList, Locale} import javax.ws.rs._ -import javax.ws.rs.core.MediaType +import javax.ws.rs.core.{Context, MediaType, MultivaluedMap, UriInfo} import org.apache.spark.SparkException import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.ApiHelper._ @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class StagesResource extends BaseAppResource { @@ -102,4 +103,120 @@ private[v1] class StagesResource extends BaseAppResource { withUI(_.store.taskList(stageId, stageAttemptId, offset, length, sortBy)) } + // This api needs to stay formatted exactly as it is below, since, it is being used by the + // datatables for the stages page. 
+ @GET + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}/taskTable") + def taskTable( + @PathParam("stageId") stageId: Int, + @PathParam("stageAttemptId") stageAttemptId: Int, + @QueryParam("details") @DefaultValue("true") details: Boolean, + @Context uriInfo: UriInfo): + HashMap[String, Object] = { + withUI { ui => + val uriQueryParameters = uriInfo.getQueryParameters(true) + val totalRecords = uriQueryParameters.getFirst("numTasks") + var isSearch = false + var searchValue: String = null + var filteredRecords = totalRecords + // The datatables client API sends a list of query parameters to the server which contain + // information like the columns to be sorted, search value typed by the user in the search + // box, pagination index etc. For more information on these query parameters, + // refer https://datatables.net/manual/server-side. + if (uriQueryParameters.getFirst("search[value]") != null && + uriQueryParameters.getFirst("search[value]").length > 0) { + isSearch = true + searchValue = uriQueryParameters.getFirst("search[value]") + } + val _tasksToShow: Seq[TaskData] = doPagination(uriQueryParameters, stageId, stageAttemptId, + isSearch, totalRecords.toInt) + val ret = new HashMap[String, Object]() + if (_tasksToShow.nonEmpty) { + // Performs server-side search based on input from user + if (isSearch) { + val filteredTaskList = filterTaskList(_tasksToShow, searchValue) + filteredRecords = filteredTaskList.length.toString + if (filteredTaskList.length > 0) { + val pageStartIndex = uriQueryParameters.getFirst("start").toInt + val pageLength = uriQueryParameters.getFirst("length").toInt + ret.put("aaData", filteredTaskList.slice( + pageStartIndex, pageStartIndex + pageLength)) + } else { + ret.put("aaData", filteredTaskList) + } + } else { + ret.put("aaData", _tasksToShow) + } + } else { + ret.put("aaData", _tasksToShow) + } + ret.put("recordsTotal", totalRecords) + ret.put("recordsFiltered", filteredRecords) + ret + } + } + + // Performs pagination on the server side + def doPagination(queryParameters: MultivaluedMap[String, String], stageId: Int, + stageAttemptId: Int, isSearch: Boolean, totalRecords: Int): Seq[TaskData] = { + var columnNameToSort = queryParameters.getFirst("columnNameToSort") + // Sorting on Logs column will default to Index column sort + if (columnNameToSort.equalsIgnoreCase("Logs")) { + columnNameToSort = "Index" + } + val isAscendingStr = queryParameters.getFirst("order[0][dir]") + var pageStartIndex = 0 + var pageLength = totalRecords + // We fetch only the desired rows upto the specified page length for all cases except when a + // search query is present, in that case, we need to fetch all the rows to perform the search + // on the entire table + if (!isSearch) { + pageStartIndex = queryParameters.getFirst("start").toInt + pageLength = queryParameters.getFirst("length").toInt + } + withUI(_.store.taskList(stageId, stageAttemptId, pageStartIndex, pageLength, + indexName(columnNameToSort), isAscendingStr.equalsIgnoreCase("asc"))) + } + + // Filters task list based on search parameter + def filterTaskList( + taskDataList: Seq[TaskData], + searchValue: String): Seq[TaskData] = { + val defaultOptionString: String = "d" + val searchValueLowerCase = searchValue.toLowerCase(Locale.ROOT) + val containsValue = (taskDataParams: Any) => taskDataParams.toString.toLowerCase( + Locale.ROOT).contains(searchValueLowerCase) + val taskMetricsContainsValue = (task: TaskData) => task.taskMetrics match { + case None => false + case Some(metrics) => + 
(containsValue(task.taskMetrics.get.executorDeserializeTime) + || containsValue(task.taskMetrics.get.executorRunTime) + || containsValue(task.taskMetrics.get.jvmGcTime) + || containsValue(task.taskMetrics.get.resultSerializationTime) + || containsValue(task.taskMetrics.get.memoryBytesSpilled) + || containsValue(task.taskMetrics.get.diskBytesSpilled) + || containsValue(task.taskMetrics.get.peakExecutionMemory) + || containsValue(task.taskMetrics.get.inputMetrics.bytesRead) + || containsValue(task.taskMetrics.get.inputMetrics.recordsRead) + || containsValue(task.taskMetrics.get.outputMetrics.bytesWritten) + || containsValue(task.taskMetrics.get.outputMetrics.recordsWritten) + || containsValue(task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime) + || containsValue(task.taskMetrics.get.shuffleReadMetrics.recordsRead) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.bytesWritten) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.recordsWritten) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.writeTime)) + } + val filteredTaskDataSequence: Seq[TaskData] = taskDataList.filter(f => + (containsValue(f.taskId) || containsValue(f.index) || containsValue(f.attempt) + || containsValue(f.launchTime) + || containsValue(f.resultFetchStart.getOrElse(defaultOptionString)) + || containsValue(f.duration.getOrElse(defaultOptionString)) + || containsValue(f.executorId) || containsValue(f.host) || containsValue(f.status) + || containsValue(f.taskLocality) || containsValue(f.speculative) + || containsValue(f.errorMessage.getOrElse(defaultOptionString)) + || taskMetricsContainsValue(f) + || containsValue(f.schedulerDelay) || containsValue(f.gettingResultTime))) + filteredTaskDataSequence + } + } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 30afd8b769720..aa21da2b66ab2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -253,7 +253,10 @@ class TaskData private[spark]( val speculative: Boolean, val accumulatorUpdates: Seq[AccumulableInfo], val errorMessage: Option[String] = None, - val taskMetrics: Option[TaskMetrics] = None) + val taskMetrics: Option[TaskMetrics] = None, + val executorLogs: Map[String, String], + val schedulerDelay: Long, + val gettingResultTime: Long) class TaskMetrics private[spark]( val executorDeserializeTime: Long, diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 646cf25880e37..ef19e86f3135f 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -283,7 +283,10 @@ private[spark] class TaskDataWrapper( speculative, accumulatorUpdates, errorMessage, - metrics) + metrics, + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) } @JsonIgnore @KVIndex(TaskIndexNames.STAGE) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 3aed4647a96f0..60a929375baae 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -204,6 +204,8 @@ private[spark] object UIUtils extends Logging { href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala deleted file mode 100644 index 1be81e5ef9952..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import scala.xml.{Node, Unparsed} - -import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.StageData -import org.apache.spark.ui.{ToolTips, UIUtils} -import org.apache.spark.util.Utils - -/** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { - - import ApiHelper._ - - def toNodeSeq: Seq[Node] = { - - - - - - - - - - {if (hasInput(stage)) { - - }} - {if (hasOutput(stage)) { - - }} - {if (hasShuffleRead(stage)) { - - }} - {if (hasShuffleWrite(stage)) { - - }} - {if (hasBytesSpilled(stage)) { - - - }} - - - - - {createExecutorTable(stage)} - -
    Executor IDAddressTask TimeTotal TasksFailed TasksKilled TasksSucceeded Tasks - Input Size / Records - - Output Size / Records - - - Shuffle Read Size / Records - - - Shuffle Write Size / Records - Shuffle Spill (Memory)Shuffle Spill (Disk) - - Blacklisted - - Logs
    - - } - - private def createExecutorTable(stage: StageData) : Seq[Node] = { - val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) - - executorSummary.toSeq.sortBy(_._1).map { case (k, v) => - val executor = store.asOption(store.executorSummary(k)) - - {k} - {executor.map { e => e.hostPort }.getOrElse("CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} - {v.failedTasks} - {v.killedTasks} - {v.succeededTasks} - {if (hasInput(stage)) { - - {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} - - }} - {if (hasOutput(stage)) { - - {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} - - }} - {if (hasShuffleRead(stage)) { - - {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} - - }} - {if (hasShuffleWrite(stage)) { - - {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} - - }} - {if (hasBytesSpilled(stage)) { - - {Utils.bytesToString(v.memoryBytesSpilled)} - - - {Utils.bytesToString(v.diskBytesSpilled)} - - }} - { - if (executor.map(_.isBlacklisted).getOrElse(false)) { - for application - } else if (v.isBlacklistedForStage) { - for stage - } else { - false - } - } - {executor.map(_.executorLogs).getOrElse(Map.empty).map { - case (logName, logUrl) => - }} - - - - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2b436b9234144..a213b764abea7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -92,6 +92,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val eventTimelineParameterTaskPage = UIUtils.stripXSS( + request.getParameter("task.eventTimelinePageNumber")) + val eventTimelineParameterTaskPageSize = UIUtils.stripXSS( + request.getParameter("task.eventTimelinePageSize")) + var eventTimelineTaskPage = Option(eventTimelineParameterTaskPage).map(_.toInt).getOrElse(1) + var eventTimelineTaskPageSize = Option( + eventTimelineParameterTaskPageSize).map(_.toInt).getOrElse(100) + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => UIUtils.decodeURLParameter(sortColumn) @@ -132,6 +140,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } else { s"$totalTasks, showing $storedTasks" } + if (eventTimelineTaskPageSize < 1 || eventTimelineTaskPageSize > totalTasks) { + eventTimelineTaskPageSize = totalTasks + } + val eventTimelineTotalPages = + (totalTasks + eventTimelineTaskPageSize - 1) / eventTimelineTaskPageSize + if (eventTimelineTaskPage < 1 || eventTimelineTaskPage > eventTimelineTotalPages) { + eventTimelineTaskPage = 1 + } val summary =
    @@ -193,73 +209,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    - val showAdditionalMetrics = -
    - - - Show Additional Metrics - - -
    - val stageGraph = parent.store.asOption(parent.store.operationGraphForStage(stageId)) val dagViz = UIUtils.showDagVizForStage(stageId, stageGraph) @@ -277,7 +226,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We stageData.accumulatorUpdates.toSeq) val currentTime = System.currentTimeMillis() - val (taskTable, taskTableHTML) = try { + val taskTable = try { val _taskTable = new TaskPagedTable( stageData, UIUtils.prependBaseUri(request, parent.basePath) + @@ -288,17 +237,10 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We desc = taskSortDesc, store = parent.store ) - (_taskTable, _taskTable.table(taskPage)) + _taskTable } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - val errorMessage = -
-            <div class="alert alert-error">
-              <p>Error while rendering stage table:</p>
-              <pre>
-                {Utils.exceptionString(e)}
-              </pre>
-            </div>
    - (null, errorMessage) + null } val jsForScrollingDownToTaskTable = @@ -316,190 +258,36 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } - val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, - Array(0, 0.25, 0.5, 0.75, 1.0)) - - val summaryTable = metricsSummary.map { metrics => - def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { - data.map { millis => - {UIUtils.formatDuration(millis.toLong)} - } - } - - def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { - data.map { size => - {Utils.bytesToString(size.toLong)} - } - } - - def sizeQuantilesWithRecords( - data: IndexedSeq[Double], - records: IndexedSeq[Double]) : Seq[Node] = { - data.zip(records).map { case (d, r) => - {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} - } - } - - def titleCell(title: String, tooltip: String): Seq[Node] = { - - - {title} - - - } - - def simpleTitleCell(title: String): Seq[Node] = {title} - - val deserializationQuantiles = titleCell("Task Deserialization Time", - ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - - val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - - val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - - val serializationQuantiles = titleCell("Result Serialization Time", - ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - - val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ - timeQuantiles(metrics.gettingResultTime) - - val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", - ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) - - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). 
- val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ - timeQuantiles(metrics.schedulerDelay) - - def inputQuantiles: Seq[Node] = { - simpleTitleCell("Input Size / Records") ++ - sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) - } - - def outputQuantiles: Seq[Node] = { - simpleTitleCell("Output Size / Records") ++ - sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten) - } - - def shuffleReadBlockedQuantiles: Seq[Node] = { - titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ - timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) - } - - def shuffleReadTotalQuantiles: Seq[Node] = { - titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ - sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, - metrics.shuffleReadMetrics.readRecords) - } - - def shuffleReadRemoteQuantiles: Seq[Node] = { - titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ - sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) - } - - def shuffleWriteQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle Write Size / Records") ++ - sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, - metrics.shuffleWriteMetrics.writeRecords) - } - - def memoryBytesSpilledQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) - } - - def diskBytesSpilledQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) - } - - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", - "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false) - } - - val executorTable = new ExecutorTable(stageData, parent.store) - - val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators(stageData)) {

    Accumulators

    ++ accumulableTable } else Seq() - - val aggMetrics = - -

    - - Aggregated Metrics by Executor -

    -
    -
    - {executorTable.toNodeSeq} -
    - val content = summary ++ - dagViz ++ - showAdditionalMetrics ++ + dagViz ++
    ++ makeTimeline( // Only show the tasks in the table - Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), - currentTime) ++ -

    Summary Metrics for {numCompleted} Completed Tasks

    ++ -
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ - aggMetrics ++ - maybeAccumulableTable ++ - -

    - - Tasks ({totalTasksNumStr}) -

    -
    ++ -
    - {taskTableHTML ++ jsForScrollingDownToTaskTable} -
    - UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) + Option(taskTable).map({ taskPagedTable => + val from = (eventTimelineTaskPage - 1) * eventTimelineTaskPageSize + val to = taskPagedTable.dataSource.dataSize.min( + eventTimelineTaskPage * eventTimelineTaskPageSize) + taskPagedTable.dataSource.sliceData(from, to)}).getOrElse(Nil), currentTime, + eventTimelineTaskPage, eventTimelineTaskPageSize, eventTimelineTotalPages, stageId, + stageAttemptId, totalTasks) ++ +
    + + +
    + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true, + useDataTables = true) + } - def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { + def makeTimeline( + tasks: Seq[TaskData], + currentTime: Long, + page: Int, + pageSize: Int, + totalPages: Int, + stageId: Int, + stageAttemptId: Int, + totalTasks: Int): Seq[Node] = { val executorsSet = new HashSet[(String, String)] var minLaunchTime = Long.MaxValue var maxFinishTime = Long.MinValue @@ -658,6 +446,31 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We Enable zooming +
    +
    + + + + + + + + + + +
    +
    {TIMELINE_LEGEND} ++ @@ -959,7 +772,7 @@ private[ui] class TaskPagedTable( } } -private[ui] object ApiHelper { +private[spark] object ApiHelper { val HEADER_ID = "ID" val HEADER_TASK_INDEX = "Index" diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json index 5e9e8230e2745..62e5c123fd3d4 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -1,639 +1,708 @@ { - "status": "COMPLETE", - "stageId": 0, - "attemptId": 0, - "numTasks": 10, - "numActiveTasks": 0, - "numCompleteTasks": 10, - "numFailedTasks": 2, - "numKilledTasks": 0, - "numCompletedIndices": 10, - "executorRunTime": 761, - "executorCpuTime": 269916000, - "submissionTime": "2018-01-09T10:21:18.152GMT", - "firstTaskLaunchedTime": "2018-01-09T10:21:18.347GMT", - "completionTime": "2018-01-09T10:21:19.062GMT", - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleReadBytes": 0, - "shuffleReadRecords": 0, - "shuffleWriteBytes": 460, - "shuffleWriteRecords": 10, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "name": "map at :26", - "details": "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", - "schedulingPool": "default", - "rddIds": [ - 1, - 0 - ], - "accumulatorUpdates": [], - "tasks": { - "0": { - "taskId": 0, - "index": 0, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.347GMT", - "duration": 562, - "executorId": "0", - "host": "172.30.65.138", - "status": "FAILED", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", - 
"taskMetrics": { - "executorDeserializeTime": 0, - "executorDeserializeCpuTime": 0, - "executorRunTime": 460, - "executorCpuTime": 0, - "resultSize": 0, - "jvmGcTime": 14, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 0, - "writeTime": 3873006, - "recordsWritten": 0 + "status" : "COMPLETE", + "stageId" : 0, + "attemptId" : 0, + "numTasks" : 10, + "numActiveTasks" : 0, + "numCompleteTasks" : 10, + "numFailedTasks" : 2, + "numKilledTasks" : 0, + "numCompletedIndices" : 10, + "executorRunTime" : 761, + "executorCpuTime" : 269916000, + "submissionTime" : "2018-01-09T10:21:18.152GMT", + "firstTaskLaunchedTime" : "2018-01-09T10:21:18.347GMT", + "completionTime" : "2018-01-09T10:21:19.062GMT", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 460, + "shuffleWriteRecords" : 10, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "map at :26", + "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "tasks" : { + "0" : { + "taskId" : 0, + "index" : 0, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.347GMT", + "duration" : 562, + "executorId" : "0", + "host" : "172.30.65.138", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 460, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 14, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 3873006, + "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout", + "stderr" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 102, + "gettingResultTime" : 0 }, - "5": { - "taskId": 5, - "index": 3, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.958GMT", - "duration": 22, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2586000, - "executorRunTime": 9, - "executorCpuTime": 9635000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 262919, - "recordsWritten": 1 + "5" : { + "taskId" : 5, + "index" : 3, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.958GMT", + "duration" : 22, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2586000, + "executorRunTime" : 9, + "executorCpuTime" : 9635000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 262919, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, - "10": { - 
"taskId": 10, - "index": 8, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.034GMT", - "duration": 12, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 1803000, - "executorRunTime": 6, - "executorCpuTime": 6157000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 243647, - "recordsWritten": 1 + "10" : { + "taskId" : 10, + "index" : 8, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.034GMT", + "duration" : 12, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 1803000, + "executorRunTime" : 6, + "executorCpuTime" : 6157000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 243647, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, - "1": { - "taskId": 1, - "index": 1, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.364GMT", - "duration": 565, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 301, - "executorDeserializeCpuTime": 200029000, - "executorRunTime": 212, - "executorCpuTime": 198479000, - "resultSize": 1115, - "jvmGcTime": 13, - "resultSerializationTime": 1, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 2409488, - "recordsWritten": 1 + "1" : { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.364GMT", + "duration" : 565, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : 
"SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 301, + "executorDeserializeCpuTime" : 200029000, + "executorRunTime" : 212, + "executorCpuTime" : 198479000, + "resultSize" : 1115, + "jvmGcTime" : 13, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 2409488, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 51, + "gettingResultTime" : 0 }, - "6": { - "taskId": 6, - "index": 4, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.980GMT", - "duration": 16, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2610000, - "executorRunTime": 10, - "executorCpuTime": 9622000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 385110, - "recordsWritten": 1 + "6" : { + "taskId" : 6, + "index" : 4, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.980GMT", + "duration" : 16, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2610000, + "executorRunTime" : 10, + "executorCpuTime" : 9622000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 385110, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "9": { - 
"taskId": 9, - "index": 7, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.022GMT", - "duration": 12, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 1981000, - "executorRunTime": 7, - "executorCpuTime": 6335000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 259354, - "recordsWritten": 1 + "9" : { + "taskId" : 9, + "index" : 7, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.022GMT", + "duration" : 12, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 1981000, + "executorRunTime" : 7, + "executorCpuTime" : 6335000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 259354, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "2": { - "taskId": 2, - "index": 2, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.899GMT", - "duration": 27, - "executorId": "0", - "host": "172.30.65.138", - "status": "FAILED", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat 
java.lang.Thread.run(Thread.java:748)\n", - "taskMetrics": { - "executorDeserializeTime": 0, - "executorDeserializeCpuTime": 0, - "executorRunTime": 16, - "executorCpuTime": 0, - "resultSize": 0, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 0, - "writeTime": 126128, - "recordsWritten": 0 + "2" : { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.899GMT", + "duration" : 27, + "executorId" : "0", + "host" : "172.30.65.138", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 126128, + "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout", + "stderr" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, - "7": { - "taskId": 7, - "index": 5, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.996GMT", - "duration": 15, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 2231000, - "executorRunTime": 9, - "executorCpuTime": 8407000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 
0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 205520, - "recordsWritten": 1 + "7" : { + "taskId" : 7, + "index" : 5, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.996GMT", + "duration" : 15, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 2231000, + "executorRunTime" : 9, + "executorCpuTime" : 8407000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 205520, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, - "3": { - "taskId": 3, - "index": 0, - "attempt": 1, - "launchTime": "2018-01-09T10:21:18.919GMT", - "duration": 24, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 8, - "executorDeserializeCpuTime": 8878000, - "executorRunTime": 10, - "executorCpuTime": 9364000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 207014, - "recordsWritten": 1 + "3" : { + "taskId" : 3, + "index" : 0, + "attempt" : 1, + "launchTime" : "2018-01-09T10:21:18.919GMT", + "duration" : 24, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 8, + "executorDeserializeCpuTime" : 8878000, + "executorRunTime" : 10, + "executorCpuTime" : 9364000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : 
{ + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 207014, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, - "11": { - "taskId": 11, - "index": 9, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.045GMT", - "duration": 15, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2017000, - "executorRunTime": 6, - "executorCpuTime": 6676000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 233652, - "recordsWritten": 1 + "11" : { + "taskId" : 11, + "index" : 9, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.045GMT", + "duration" : 15, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2017000, + "executorRunTime" : 6, + "executorCpuTime" : 6676000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 233652, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, - "8": { - "taskId": 8, - "index": 6, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.011GMT", - "duration": 11, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 1, - "executorDeserializeCpuTime": 1554000, - "executorRunTime": 7, - "executorCpuTime": 6034000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - 
"diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 213296, - "recordsWritten": 1 + "8" : { + "taskId" : 8, + "index" : 6, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.011GMT", + "duration" : 11, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 1554000, + "executorRunTime" : 7, + "executorCpuTime" : 6034000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 213296, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "4": { - "taskId": 4, - "index": 2, - "attempt": 1, - "launchTime": "2018-01-09T10:21:18.943GMT", - "duration": 16, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 2211000, - "executorRunTime": 9, - "executorCpuTime": 9207000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 292381, - "recordsWritten": 1 + "4" : { + "taskId" : 4, + "index" : 2, + "attempt" : 1, + "launchTime" : "2018-01-09T10:21:18.943GMT", + "duration" : 16, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 2211000, + "executorRunTime" : 9, + "executorCpuTime" : 9207000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + 
"recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 292381, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } }, - "executorSummary": { - "0": { - "taskTime": 589, - "failedTasks": 2, - "succeededTasks": 0, - "killedTasks": 0, - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleRead": 0, - "shuffleReadRecords": 0, - "shuffleWrite": 0, - "shuffleWriteRecords": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "isBlacklistedForStage": true + "executorSummary" : { + "0" : { + "taskTime" : 589, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true }, - "1": { - "taskTime": 708, - "failedTasks": 0, - "succeededTasks": 10, - "killedTasks": 0, - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleRead": 0, - "shuffleReadRecords": 0, - "shuffleWrite": 460, - "shuffleWriteRecords": 10, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "isBlacklistedForStage": false + "1" : { + "taskTime" : 708, + "failedTasks" : 0, + "succeededTasks" : 10, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 460, + "shuffleWriteRecords" : 10, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, - "killedTasksSummary": {} + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json index acd4cc53de6cd..6e46c881b2a21 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -74,7 +74,13 @@ "writeTime" : 3662221, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 68, + "gettingResultTime" : 0 }, "5" : { "taskId" : 5, @@ -122,7 +128,13 @@ "writeTime" : 191901, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 46, + "gettingResultTime" : 0 }, "10" : 
{ "taskId" : 10, @@ -169,7 +181,13 @@ "writeTime" : 301705, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 50, + "gettingResultTime" : 0 }, "1" : { "taskId" : 1, @@ -217,7 +235,13 @@ "writeTime" : 3075188, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 174, + "gettingResultTime" : 0 }, "6" : { "taskId" : 6, @@ -265,7 +289,13 @@ "writeTime" : 183718, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -312,7 +342,13 @@ "writeTime" : 366050, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 42, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -359,7 +395,13 @@ "writeTime" : 369513, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 26, + "gettingResultTime" : 0 }, "2" : { "taskId" : 2, @@ -406,7 +448,13 @@ "writeTime" : 3322956, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 74, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -453,7 +501,13 @@ "writeTime" : 319101, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, "7" : { "taskId" : 7, @@ -500,7 +554,13 @@ "writeTime" : 377601, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : 
"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, "3" : { "taskId" : 3, @@ -547,7 +607,13 @@ "writeTime" : 3587839, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 63, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -594,7 +660,13 @@ "writeTime" : 323898, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 12, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -641,7 +713,13 @@ "writeTime" : 311940, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 84, + "gettingResultTime" : 0 }, "4" : { "taskId" : 4, @@ -689,7 +767,13 @@ "writeTime" : 16858066, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 338, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 03f886afa5413..aa9471301fe3e 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -74,7 +74,10 @@ "writeTime" : 76000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 19, + "gettingResultTime" : 0 }, "14" : { "taskId" : 14, @@ -121,7 +124,10 @@ "writeTime" : 88000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -168,7 +174,10 @@ "writeTime" : 98000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -215,7 +224,10 @@ "writeTime" : 73000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -262,7 +274,10 @@ "writeTime" : 101000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -309,7 +324,10 @@ "writeTime" : 83000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -356,7 +374,10 @@ "writeTime" : 94000, 
"recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "15" : { "taskId" : 15, @@ -403,7 +424,10 @@ "writeTime" : 79000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 947c89906955d..584803b5e8631 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -74,7 +74,10 @@ "writeTime" : 76000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 19, + "gettingResultTime" : 0 }, "14" : { "taskId" : 14, @@ -121,7 +124,10 @@ "writeTime" : 88000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -168,7 +174,10 @@ "writeTime" : 98000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -215,7 +224,10 @@ "writeTime" : 73000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -262,7 +274,10 @@ "writeTime" : 101000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -309,7 +324,10 @@ "writeTime" : 83000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -356,7 +374,10 @@ "writeTime" : 94000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "15" : { "taskId" : 15, @@ -403,7 +424,10 @@ "writeTime" : 79000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index a15ee23523365..f859ab6fff240 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -89,7 +92,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -135,7 +141,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -181,7 +190,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -273,7 +288,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + 
"gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -319,7 +337,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -365,7 +386,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -411,7 +435,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -457,7 +484,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 10, "index" : 10, @@ -503,7 +533,10 @@ "writeTime" : 94709, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 11, "index" : 11, @@ -549,7 +582,10 @@ "writeTime" : 94507, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -595,7 +631,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 13, "index" : 13, @@ -641,7 +680,10 @@ "writeTime" : 95004, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -687,7 +729,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -733,7 +778,10 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -779,7 +827,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -825,7 +876,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -871,7 +925,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -917,5 +974,8 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index f9182b1658334..ea88ca116707a 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -48,7 +48,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -99,7 +102,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -150,7 +156,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + 
"executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -201,7 +210,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -252,7 +264,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -303,7 +318,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -354,7 +372,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -405,5 +426,8 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 76dd2f710b90f..efd0a45bf01d0 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -48,7 +48,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -99,7 +102,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -150,7 +156,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -201,7 +210,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -252,7 +264,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -303,7 +318,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -354,7 +372,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -405,5 +426,8 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 6bdc10465d89e..d83528d84972c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 94709, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 11, "index" : 11, @@ -89,7 +92,10 @@ 
"writeTime" : 94507, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -135,7 +141,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 13, "index" : 13, @@ -181,7 +190,10 @@ "writeTime" : 95004, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -227,7 +239,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -273,7 +288,10 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -319,7 +337,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -365,7 +386,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -411,7 +435,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -457,7 +484,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -503,7 +533,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -595,7 +631,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 23, "index" : 23, @@ -641,7 +680,10 @@ "writeTime" : 91844, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 24, "index" : 24, @@ -687,7 +729,10 @@ "writeTime" : 157194, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 25, "index" : 25, @@ -733,7 +778,10 @@ "writeTime" : 94134, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 9, + "gettingResultTime" : 0 }, { "taskId" : 26, "index" : 26, @@ -779,7 +827,10 @@ "writeTime" : 108213, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 27, "index" : 27, @@ -825,7 +876,10 @@ "writeTime" : 102019, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 28, "index" : 28, @@ -871,7 +925,10 @@ "writeTime" : 104299, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, { "taskId" : 29, "index" : 29, @@ -917,7 +974,10 @@ "writeTime" : 114938, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, { "taskId" : 30, "index" : 30, @@ -963,7 +1023,10 @@ "writeTime" : 119770, "recordsWritten" : 10 } - } + }, + 
"executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 31, "index" : 31, @@ -1009,7 +1072,10 @@ "writeTime" : 92619, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, { "taskId" : 32, "index" : 32, @@ -1055,7 +1121,10 @@ "writeTime" : 89603, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 33, "index" : 33, @@ -1101,7 +1170,10 @@ "writeTime" : 118329, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, { "taskId" : 34, "index" : 34, @@ -1147,7 +1219,10 @@ "writeTime" : 127746, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 9, + "gettingResultTime" : 0 }, { "taskId" : 35, "index" : 35, @@ -1193,7 +1268,10 @@ "writeTime" : 160963, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, { "taskId" : 36, "index" : 36, @@ -1239,7 +1317,10 @@ "writeTime" : 123855, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 37, "index" : 37, @@ -1285,7 +1366,10 @@ "writeTime" : 111869, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 38, "index" : 38, @@ -1331,7 +1415,10 @@ "writeTime" : 131158, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 39, "index" : 39, @@ -1377,7 +1464,10 @@ "writeTime" : 98748, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 40, "index" : 40, @@ -1423,7 +1513,10 @@ "writeTime" : 94792, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 41, "index" : 41, @@ -1469,7 +1562,10 @@ "writeTime" : 90765, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 42, "index" : 42, @@ -1515,7 +1611,10 @@ "writeTime" : 103713, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 15, + "gettingResultTime" : 0 }, { "taskId" : 43, "index" : 43, @@ -1561,7 +1660,10 @@ "writeTime" : 171516, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 21, + "gettingResultTime" : 0 }, { "taskId" : 44, "index" : 44, @@ -1607,7 +1709,10 @@ "writeTime" : 98293, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 45, "index" : 45, @@ -1653,7 +1758,10 @@ "writeTime" : 92985, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 46, "index" : 46, @@ -1699,7 +1807,10 @@ "writeTime" : 113322, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, { "taskId" : 47, "index" : 47, @@ -1745,7 +1856,10 @@ "writeTime" : 103015, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 48, "index" : 48, @@ -1791,7 +1905,10 @@ "writeTime" : 139844, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 49, "index" : 49, @@ -1837,7 +1954,10 @@ "writeTime" : 94984, "recordsWritten" : 10 } - } + }, + "executorLogs" 
: { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 50, "index" : 50, @@ -1883,7 +2003,10 @@ "writeTime" : 90836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 51, "index" : 51, @@ -1929,7 +2052,10 @@ "writeTime" : 96013, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 52, "index" : 52, @@ -1975,7 +2101,10 @@ "writeTime" : 89664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 53, "index" : 53, @@ -2021,7 +2150,10 @@ "writeTime" : 92835, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 54, "index" : 54, @@ -2067,7 +2199,10 @@ "writeTime" : 90506, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 55, "index" : 55, @@ -2113,7 +2248,10 @@ "writeTime" : 108309, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 56, "index" : 56, @@ -2159,7 +2297,10 @@ "writeTime" : 90329, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 57, "index" : 57, @@ -2205,7 +2346,10 @@ "writeTime" : 96849, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 58, "index" : 58, @@ -2251,7 +2395,10 @@ "writeTime" : 97521, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 59, "index" : 59, @@ -2297,5 +2444,8 @@ "writeTime" : 100753, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index bc1cd49909d31..82e339c8f56dd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -89,7 +92,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -135,7 +141,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -181,7 +190,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -273,7 +288,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 0, "index" : 0, @@ -319,7 +337,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + 
"schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -365,7 +386,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -411,7 +435,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -457,7 +484,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -503,7 +533,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -595,7 +631,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -641,7 +680,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -687,7 +729,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -733,7 +778,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -779,7 +827,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -825,7 +876,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -871,7 +925,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -917,5 +974,8 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index bc1cd49909d31..82e339c8f56dd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -89,7 +92,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -135,7 +141,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -181,7 +190,10 @@ "writeTime" : 
2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -273,7 +288,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 0, "index" : 0, @@ -319,7 +337,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -365,7 +386,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -411,7 +435,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -457,7 +484,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -503,7 +533,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -595,7 +631,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -641,7 +680,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -687,7 +729,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -733,7 +778,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -779,7 +827,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -825,7 +876,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -871,7 +925,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -917,5 +974,8 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index 09857cb401acd..01eef1b565bf6 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 94792, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 41, "index" : 41, @@ -89,7 +92,10 @@ "writeTime" : 90765, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 43, "index" : 43, @@ -135,7 +141,10 @@ "writeTime" : 171516, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 21, + "gettingResultTime" : 0 }, { "taskId" : 57, "index" : 57, @@ -181,7 +190,10 @@ "writeTime" : 96849, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 58, "index" : 58, @@ -227,7 +239,10 @@ "writeTime" : 97521, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 68, "index" : 68, @@ -273,7 +288,10 @@ "writeTime" : 101750, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 86, "index" : 86, @@ -319,7 +337,10 @@ "writeTime" : 95848, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 32, "index" : 32, @@ -365,7 +386,10 @@ "writeTime" : 89603, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 39, "index" : 39, @@ -411,7 +435,10 @@ "writeTime" : 98748, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 42, "index" : 42, @@ -457,7 +484,10 @@ "writeTime" : 103713, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 15, + "gettingResultTime" : 0 }, { "taskId" : 51, "index" : 51, @@ -503,7 +533,10 @@ "writeTime" : 96013, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 59, "index" : 59, @@ -549,7 +582,10 @@ "writeTime" : 100753, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 63, "index" : 63, @@ -595,7 +631,10 @@ "writeTime" : 102779, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 87, "index" : 87, @@ -641,7 +680,10 @@ "writeTime" : 102159, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 90, "index" : 90, @@ -687,7 +729,10 @@ "writeTime" : 98472, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 99, "index" : 99, @@ -733,7 +778,10 @@ "writeTime" : 133964, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 44, "index" : 44, @@ -779,7 +827,10 @@ "writeTime" : 98293, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 47, "index" : 47, @@ -825,7 +876,10 @@ "writeTime" : 103015, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 50, "index" : 50, @@ -871,7 +925,10 @@ "writeTime" : 90836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + 
"gettingResultTime" : 0 }, { "taskId" : 52, "index" : 52, @@ -917,5 +974,8 @@ "writeTime" : 89664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 963f010968b62..a8e1fd303a42a 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -83,14 +83,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, - "1" : { - "taskId" : 1, - "index" : 1, + "5" : { + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.521GMT", - "duration" : 53, + "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -99,11 +102,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "247", - "value" : "2175" + "update" : "897", + "value" : "3750" } ], "taskMetrics" : { - "executorDeserializeTime" : 14, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -135,14 +138,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, - "2" : { - "taskId" : 2, - "index" : 2, + "1" : { + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 48, + "launchTime" : "2015-03-16T19:25:36.521GMT", + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -151,11 +157,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "378", - "value" : "378" + "update" : "247", + "value" : "2175" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 14, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -187,14 +193,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "3" : { - "taskId" : 3, - "index" : 3, + "6" : { + "taskId" : 6, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 50, + "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -203,11 +212,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "572", - "value" : "950" + "update" : "978", + "value" : "1928" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -239,14 +248,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "4" : { - "taskId" : 4, - "index" : 4, + "2" : { + "taskId" : 2, + "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 52, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -255,17 +267,17 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "678", - "value" : "2853" + "update" : "378", + "value" : "378" } ], "taskMetrics" : { - 
"executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "peakExecutionMemory" : 0, @@ -291,14 +303,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, - "5" : { - "taskId" : 5, - "index" : 5, + "7" : { + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 52, + "launchTime" : "2015-03-16T19:25:36.524GMT", + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -307,8 +322,8 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "897", - "value" : "3750" + "update" : "1222", + "value" : "4972" } ], "taskMetrics" : { "executorDeserializeTime" : 12, @@ -343,14 +358,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "6" : { - "taskId" : 6, - "index" : 6, + "3" : { + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 51, + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -359,11 +377,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "978", - "value" : "1928" + "update" : "572", + "value" : "950" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -395,14 +413,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, - "7" : { - "taskId" : 7, - "index" : 7, + "4" : { + "taskId" : 4, + "index" : 4, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.524GMT", - "duration" : 51, + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -411,8 +432,8 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "1222", - "value" : "4972" + "update" : "678", + "value" : "2853" } ], "taskMetrics" : { "executorDeserializeTime" : 12, @@ -421,7 +442,7 @@ "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "peakExecutionMemory" : 0, @@ -447,7 +468,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala index 9e74e86ad54b9..a01b24d323d28 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala @@ -52,7 +52,10 @@ class AppStatusUtilsSuite extends SparkFunSuite { inputMetrics = null, outputMetrics = null, shuffleReadMetrics = null, - shuffleWriteMetrics = null))) + shuffleWriteMetrics = null)), + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) assert(AppStatusUtils.schedulerDelay(runningTask) === 0L) val finishedTask = new TaskData( @@ -83,7 +86,10 @@ class 
AppStatusUtilsSuite extends SparkFunSuite { inputMetrics = null, outputMetrics = null, shuffleReadMetrics = null, - shuffleWriteMetrics = null))) + shuffleWriteMetrics = null)), + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L) } } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 2945c3ee0a9d9..5e976ae4e91da 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -96,18 +96,6 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { } } - test("peak execution memory should displayed") { - val html = renderStagePage().toString().toLowerCase(Locale.ROOT) - val targetString = "peak execution memory" - assert(html.contains(targetString)) - } - - test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val html = renderStagePage().toString().toLowerCase(Locale.ROOT) - // verify min/25/50/75/max show task value not cumulative values - assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) - } - /** * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. From fbf62b7100be992cbc4eb67e154682db6c91e60e Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 26 Nov 2018 13:13:06 -0800 Subject: [PATCH 126/145] [SPARK-25451][SPARK-26100][CORE] Aggregated metrics table doesn't show the right number of the total tasks The total task counts in the aggregated metrics table and in the tasks table sometimes do not match in the Web UI. We need to force-update the executor summary of a particular executorId whenever the last task of that executor has finished. Currently the force update is triggered only by the last task at stage end, so tasks of some particular executorId can be missed at stage end. Tests to reproduce: ``` bin/spark-shell --master yarn --conf spark.executor.instances=3 sc.parallelize(1 to 10000, 10).map{ x => throw new RuntimeException("Bad executor")}.collect() ``` Before patch: ![screenshot from 2018-11-15 02-24-05](https://user-images.githubusercontent.com/23054875/48511776-b0d36480-e87d-11e8-89a8-ab97216e2c21.png) After patch: ![screenshot from 2018-11-15 02-32-38](https://user-images.githubusercontent.com/23054875/48512141-c39a6900-e87e-11e8-8535-903e1d11d13e.png) Closes #23038 from shahidki31/SPARK-25451.
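As context for the fix that follows, here is a minimal, hypothetical sketch of the per-executor bookkeeping the description refers to (the object and method names are invented for illustration; the actual change lives in `AppStatusListener` below): count the running tasks of each executor and flush that executor's summary the moment its count drops to zero, instead of waiting for stage end.

```scala
import scala.collection.mutable

// Illustration only: track active tasks per executor so the executor summary
// can be force-updated when that executor's last task finishes.
object PerExecutorTaskBookkeeping {
  private val activeTasksPerExecutor = mutable.HashMap[String, Int]().withDefaultValue(0)

  def onTaskStart(executorId: String): Unit =
    activeTasksPerExecutor(executorId) += 1

  // Returns true exactly when the finished task was the executor's last active one,
  // i.e. the point at which its per-executor stage summary should be written out.
  def onTaskEnd(executorId: String): Boolean = {
    activeTasksPerExecutor(executorId) -= 1
    activeTasksPerExecutor(executorId) == 0
  }

  def main(args: Array[String]): Unit = {
    onTaskStart("1"); onTaskStart("1")
    println(onTaskEnd("1")) // false: executor "1" still has a running task
    println(onTaskEnd("1")) // true: its last task finished, update the summary now
  }
}
```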
Authored-by: Shahid Signed-off-by: Marcelo Vanzin --- .../spark/status/AppStatusListener.scala | 19 +++++++- .../org/apache/spark/status/LiveEntity.scala | 2 + .../spark/status/AppStatusListenerSuite.scala | 45 +++++++++++++++++++ 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 81d39e0407fed..8e845573a903d 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -473,6 +473,7 @@ private[spark] class AppStatusListener( val locality = event.taskInfo.taskLocality.toString() val count = stage.localitySummary.getOrElse(locality, 0L) + 1L stage.localitySummary = stage.localitySummary ++ Map(locality -> count) + stage.activeTasksPerExecutor(event.taskInfo.executorId) += 1 maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -558,6 +559,7 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) } + stage.activeTasksPerExecutor(event.taskInfo.executorId) -= 1 // [SPARK-24415] Wait for all tasks to finish before removing stage from live list val removeStage = stage.activeTasks == 0 && @@ -582,7 +584,11 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { job.killedSummary = killedTasksSummary(event.reason, job.killedSummary) } - conditionalLiveUpdate(job, now, removeStage) + if (removeStage) { + update(job, now) + } else { + maybeUpdate(job, now) + } } val esummary = stage.executorSummary(event.taskInfo.executorId) @@ -593,7 +599,16 @@ private[spark] class AppStatusListener( if (metricsDelta != null) { esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } - conditionalLiveUpdate(esummary, now, removeStage) + + val isLastTask = stage.activeTasksPerExecutor(event.taskInfo.executorId) == 0 + + // If the last task of the executor finished, then update the esummary + // for both live and history events. + if (isLastTask) { + update(esummary, now) + } else { + maybeUpdate(esummary, now) + } if (!stage.cleaning && stage.savedTasks.get() > maxTasksPerStage) { stage.cleaning = true diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 80663318c1ba1..47e45a66ecccb 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -376,6 +376,8 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + val activeTasksPerExecutor = new HashMap[String, Int]().withDefaultValue(0) + var blackListedExecutors = new HashSet[String]() // Used for cleanup of tasks after they reach the configured limit. Not written to the store. 
diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 5f757b757ac61..1c787ff43b9ac 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1273,6 +1273,51 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(allJobs.head.numFailedStages == 1) } + test("SPARK-25451: total tasks in the executor summary should match total stage tasks") { + val testConf = conf.clone.set(LIVE_ENTITY_UPDATE_PERIOD, Long.MaxValue) + + val listener = new AppStatusListener(store, testConf, true) + + val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details") + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null)) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + val tasks = createTasks(4, Array("1", "2")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task)) + } + + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + Success, tasks(0), null)) + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + Success, tasks(1), null)) + + stage.failureReason = Some("Failed") + listener.onStageCompleted(SparkListenerStageCompleted(stage)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobFailed(new RuntimeException("Bad Executor")))) + + time += 1 + tasks(2).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + ExecutorLostFailure("1", true, Some("Lost executor")), tasks(2), null)) + time += 1 + tasks(3).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + ExecutorLostFailure("2", true, Some("Lost executor")), tasks(3), null)) + + val esummary = store.view(classOf[ExecutorStageSummaryWrapper]).asScala.map(_.info) + esummary.foreach { execSummary => + assert(execSummary.failedTasks === 1) + assert(execSummary.succeededTasks === 1) + assert(execSummary.killedTasks === 0) + } + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) From 6f1a1c1248e0341a690aee655af05da9e9cbff90 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Nov 2018 14:37:41 -0800 Subject: [PATCH 127/145] [SPARK-25451][HOTFIX] Call stage.attemptNumber instead of attemptId. Closes #23149 from vanzin/SPARK-25451.hotfix. 
Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/status/AppStatusListenerSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 1c787ff43b9ac..7860a0df4bb2d 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1289,11 +1289,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", Success, tasks(0), null)) time += 1 tasks(1).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", Success, tasks(1), null)) stage.failureReason = Some("Failed") @@ -1303,11 +1303,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 tasks(2).markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", ExecutorLostFailure("1", true, Some("Lost executor")), tasks(2), null)) time += 1 tasks(3).markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", ExecutorLostFailure("2", true, Some("Lost executor")), tasks(3), null)) val esummary = store.view(classOf[ExecutorStageSummaryWrapper]).asScala.map(_.info) From 9deaa726ef1645746892a23d369c3d14677a48ff Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Nov 2018 15:33:21 -0800 Subject: [PATCH 128/145] [INFRA] Close stale PR. Closes #23107 From c995e0737de66441052fbf0fb941c5ea05d0163f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 26 Nov 2018 17:01:56 -0800 Subject: [PATCH 129/145] [SPARK-26140] followup: rename ShuffleMetricsReporter ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/23105, due to working on two parallel PRs at once, I made the mistake of committing the copy of the PR that used the name ShuffleMetricsReporter for the interface, rather than the appropriate one ShuffleReadMetricsReporter. This patch fixes that. ## How was this patch tested? This should be fine as long as compilation passes. Closes #23147 from rxin/ShuffleReadMetricsReporter. 
Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../spark/executor/ShuffleReadMetrics.scala | 4 +-- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../apache/spark/shuffle/ShuffleManager.scala | 2 +- .../shuffle/ShuffleMetricsReporter.scala | 33 ------------------- .../shuffle/sort/SortShuffleManager.scala | 2 +- .../storage/ShuffleBlockFetcherIterator.scala | 4 +-- 6 files changed, 7 insertions(+), 40 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 2f97e969d2dd2..12c4b8f67f71c 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -18,7 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.shuffle.ShuffleMetricsReporter +import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.util.LongAccumulator @@ -130,7 +130,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { * shuffle dependency, and all temporary metrics will be merged into the [[ShuffleReadMetrics]] at * last. */ -private[spark] class TempShuffleReadMetrics extends ShuffleMetricsReporter { +private[spark] class TempShuffleReadMetrics extends ShuffleReadMetricsReporter { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7cb031ce318b7..27e2f98c58f0c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -33,7 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, - readMetrics: ShuffleMetricsReporter, + readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index d1061d83cb85a..df601cbdb2050 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -49,7 +49,7 @@ private[spark] trait ShuffleManager { startPartition: Int, endPartition: Int, context: TaskContext, - metrics: ShuffleMetricsReporter): ShuffleReader[K, C] + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] /** * Remove a shuffle's metadata from the ShuffleManager. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala deleted file mode 100644 index 32865149c97c2..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -/** - * An interface for reporting shuffle information, for each shuffle. This interface assumes - * all the methods are called on a single-threaded, i.e. concrete implementations would not need - * to synchronize anything. - */ -private[spark] trait ShuffleMetricsReporter { - def incRemoteBlocksFetched(v: Long): Unit - def incLocalBlocksFetched(v: Long): Unit - def incRemoteBytesRead(v: Long): Unit - def incRemoteBytesReadToDisk(v: Long): Unit - def incLocalBytesRead(v: Long): Unit - def incFetchWaitTime(v: Long): Unit - def incRecordsRead(v: Long): Unit -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 57c3150e5a697..4f8be198e4a72 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -115,7 +115,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext, - metrics: ShuffleMetricsReporter): ShuffleReader[K, C] = { + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context, metrics) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a2e0713e70b04..86f7c08eddcb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf -import org.apache.spark.shuffle.{FetchFailedException, ShuffleMetricsReporter} +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -73,7 +73,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, - shuffleMetrics: ShuffleMetricsReporter) + shuffleMetrics: ShuffleReadMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ From 1c487f7d1442a7043e7faff76ab67a633edc7b05 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Nov 2018 12:13:48 +0800 Subject: [PATCH 130/145] [SPARK-24762][SQL][FOLLOWUP] Enable Option of Product encoders ## What changes 
were proposed in this pull request? This is follow-up of #21732. This patch inlines `isOptionType` method. ## How was this patch tested? Existing tests. Closes #23143 from viirya/SPARK-24762-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/encoders/ExpressionEncoder.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index d019924711e3e..589e215c55e44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -251,11 +251,6 @@ case class ExpressionEncoder[T]( */ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] - /** - * Returns true if the type `T` is an `Option` type. - */ - def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) - /** * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in * the struct are naturally mapped to top-level columns in a row. In other words, the serialized @@ -263,7 +258,9 @@ case class ExpressionEncoder[T]( * flattened to top-level row, because in Spark SQL top-level row can't be null. This method * returns true if `T` is serialized as struct and is not `Option` type. */ - def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType + def isSerializedAsStructForTopLevel: Boolean = { + isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + } // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This From 85383d29ede19dd73949fe57cadb73ec94b29334 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Nov 2018 04:51:32 +0000 Subject: [PATCH 131/145] [SPARK-25860][SPARK-26107][FOLLOW-UP] Rule ReplaceNullWithFalseInPredicate ## What changes were proposed in this pull request? Based on https://github.com/apache/spark/pull/22857 and https://github.com/apache/spark/pull/23079, this PR did a few updates - Limit the data types of NULL to Boolean. - Limit the input data type of replaceNullWithFalse to Boolean; throw an exception in the testing mode. - Create a new file for the rule ReplaceNullWithFalseInPredicate - Update the description of this rule. ## How was this patch tested? Added a test case Closes #23139 from gatorsmile/followupSpark-25860. 
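Concretely, the new scope boundary is the one exercised by the updated suite (condensed from the test diff below; `testFilter`, `testJoin` and the literals are the suite's existing helpers):

```scala
// A NULL literal of BooleanType in a search condition is still folded to FALSE.
testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)

// A NULL literal of any other type is no longer rewritten; in testing mode the
// rule now fails fast instead of silently replacing it.
val e = intercept[IllegalArgumentException] {
  testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral)
}.getMessage
assert(e.contains("but got the type `int` in `CAST(NULL AS INT)"))
```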
Authored-by: gatorsmile Signed-off-by: DB Tsai --- .../ReplaceNullWithFalseInPredicate.scala | 110 ++++++++++++++++++ .../sql/catalyst/optimizer/expressions.scala | 66 ----------- ...ReplaceNullWithFalseInPredicateSuite.scala | 11 +- 3 files changed, 119 insertions(+), 68 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala new file mode 100644 index 0000000000000..72a60f692ac78 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} +import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.util.Utils + + +/** + * A rule that replaces `Literal(null, BooleanType)` with `FalseLiteral`, if possible, in the search + * condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator + * "(search condition) = TRUE". The replacement is only valid when `Literal(null, BooleanType)` is + * semantically equivalent to `FalseLiteral` when evaluating the whole search condition. + * + * Please note that FALSE and NULL are not exchangeable in most cases, when the search condition + * contains NOT and NULL-tolerant expressions. Thus, the rule is very conservative and applicable + * in very limited cases. + * + * For example, `Filter(Literal(null, BooleanType))` is equal to `Filter(FalseLiteral)`. + * + * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; + * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually + * `Filter(FalseLiteral)`. + * + * Moreover, this rule also transforms predicates in all [[If]] expressions as well as branch + * conditions in all [[CaseWhen]] expressions, even if they are not part of the search conditions. + * + * For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` can be simplified + * into `Project(Literal(2))`. 
+ */ +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) + case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case p: LogicalPlan => p transformExpressions { + case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) + case cw @ CaseWhen(branches, _) => + val newBranches = branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> value + } + cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) + } + } + + /** + * Recursively traverse the Boolean-type expression to replace + * `Literal(null, BooleanType)` with `FalseLiteral`, if possible. + * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or + * `Literal(null, BooleanType)`. + */ + private def replaceNullWithFalse(e: Expression): Expression = e match { + case Literal(null, BooleanType) => + FalseLiteral + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) + case e if e.dataType == BooleanType => + e + case e => + val message = "Expected a Boolean type expression in replaceNullWithFalse, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`." + if (Utils.isTesting) { + throw new IllegalArgumentException(message) + } else { + logWarning(message) + e + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 354efd883f814..468a950fb1087 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -736,69 +736,3 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } - -/** - * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations. - * - * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates - * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions. - * - * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`. 
- * - * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; - * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually - * `Filter(FalseLiteral)`. - * - * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can - * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` - * can be simplified into `Project(Literal(2))`. - * - * As a result, many unnecessary computations can be removed in the query optimization phase. - */ -object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) - case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) - case p: LogicalPlan => p transformExpressions { - case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) - case cw @ CaseWhen(branches, _) => - val newBranches = branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> value - } - cw.copy(branches = newBranches) - case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - af.copy(function = newLambda) - case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - ae.copy(function = newLambda) - case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - mf.copy(function = newLambda) - } - } - - /** - * Recursively replaces `Literal(null, _)` with `FalseLiteral`. - * - * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit - * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`. 
- */ - private def replaceNullWithFalse(e: Expression): Expression = e match { - case cw: CaseWhen if cw.dataType == BooleanType => - val newBranches = cw.branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> replaceNullWithFalse(value) - } - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) - case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => - If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) - case And(left, right) => - And(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Or(left, right) => - Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Literal(null, _) => FalseLiteral - case _ => e - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 3a9e6cae0fd87..ee0d04da3e46c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -44,8 +44,15 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { private val anotherTestRelation = LocalRelation('d.int) test("replace null inside filter and join conditions") { - testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) - testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) + testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + } + + test("Not expected type - replaceNullWithFalse") { + val e = intercept[IllegalArgumentException] { + testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral) + }.getMessage + assert(e.contains("but got the type `int` in `CAST(NULL AS INT)")) } test("replace null in branches of If") { From 6a064ba8f271d5f9d04acd41d0eea50a5b0f5018 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 26 Nov 2018 22:35:52 -0800 Subject: [PATCH 132/145] [SPARK-26141] Enable custom metrics implementation in shuffle write ## What changes were proposed in this pull request? This is the write side counterpart to https://github.com/apache/spark/pull/23105 ## How was this patch tested? No behavior change expected, as it is a straightforward refactoring. Updated all existing test cases. Closes #23106 from rxin/SPARK-26141. 
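The shape of the change mirrors the read side: `ShuffleManager.getWriter` now takes the reporter explicitly, and the built-in task path simply passes the task's own `ShuffleWriteMetrics`, so default behavior is unchanged. Condensed from the diff below:

```scala
// ShuffleManager: a writer is now created together with a write-metrics reporter.
def getWriter[K, V](
    handle: ShuffleHandle,
    mapId: Int,
    context: TaskContext,
    metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]

// ShuffleMapTask: the default path keeps reporting into TaskMetrics.
writer = manager.getWriter[Any, Any](
  dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
```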
Authored-by: Reynold Xin Signed-off-by: Reynold Xin --- .../sort/BypassMergeSortShuffleWriter.java | 11 +++++------ .../shuffle/sort/ShuffleExternalSorter.java | 18 ++++++++++++------ .../shuffle/sort/UnsafeShuffleWriter.java | 9 +++++---- .../storage/TimeTrackingOutputStream.java | 7 ++++--- .../spark/executor/ShuffleWriteMetrics.scala | 13 +++++++------ .../spark/scheduler/ShuffleMapTask.scala | 3 ++- .../apache/spark/shuffle/ShuffleManager.scala | 6 +++++- .../shuffle/sort/SortShuffleManager.scala | 10 ++++++---- .../apache/spark/storage/BlockManager.scala | 7 +++---- .../spark/storage/DiskBlockObjectWriter.scala | 4 ++-- .../spark/util/collection/ExternalSorter.scala | 4 ++-- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 6 ++++-- .../scala/org/apache/spark/ShuffleSuite.scala | 12 ++++++++---- .../BypassMergeSortShuffleWriterSuite.scala | 16 ++++++++-------- project/MimaExcludes.scala | 7 ++++++- 15 files changed, 79 insertions(+), 54 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index b020a6d99247b..fda33cd8293d5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -37,12 +37,11 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; @@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int numPartitions; private final BlockManager blockManager; private final Partitioner partitioner; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; private final Serializer serializer; @@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, - TaskContext taskContext, - SparkConf conf) { + SparkConf conf, + ShuffleWriteMetricsReporter writeMetrics) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 1c0d664afb138..6ee9d5f0eec3b 100644 --- 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -38,6 +38,7 @@ import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.FileSegment; @@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; /** * Force this sorter to spill when there are this many elements in memory. @@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { int initialSize, int numPartitions, SparkConf conf, - ShuffleWriteMetrics writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics) { super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()), memoryManager.getTungstenMemoryMode()); @@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { */ private void writeSortedFile(boolean isLastFile) { - final ShuffleWriteMetrics writeMetricsToUse; + final ShuffleWriteMetricsReporter writeMetricsToUse; if (isLastFile) { // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. @@ -241,9 +242,14 @@ private void writeSortedFile(boolean isLastFile) { // // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. - // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. - writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten()); - taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten()); + // SPARK-3577 tracks the spill time separately. + + // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning + // of this method. 
+ writeMetrics.incRecordsWritten( + ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled( + ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten()); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..4b0c74341551e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -37,7 +37,6 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; @@ -47,6 +46,7 @@ import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; @@ -73,7 +73,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -122,7 +122,8 @@ public UnsafeShuffleWriter( SerializedShuffleHandle handle, int mapId, TaskContext taskContext, - SparkConf sparkConf) throws IOException { + SparkConf sparkConf, + ShuffleWriteMetricsReporter writeMetrics) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -138,7 +139,7 @@ public UnsafeShuffleWriter( this.shuffleId = dep.shuffleId(); this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + this.writeMetrics = writeMetrics; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java index 5d0555a8c28e1..fcba3b73445c9 100644 --- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -21,7 +21,7 @@ import java.io.OutputStream; import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; /** * Intercepts write calls and tracks total time spent writing in order to update shuffle write @@ -30,10 +30,11 @@ @Private public final class TimeTrackingOutputStream extends OutputStream { - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final OutputStream outputStream; - public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + 
public TimeTrackingOutputStream( + ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) { this.writeMetrics = writeMetrics; this.outputStream = outputStream; } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index 0c9da657c2b60..d0b0e7da079c9 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.LongAccumulator @@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator * Operations are not thread-safe. */ @DeveloperApi -class ShuffleWriteMetrics private[spark] () extends Serializable { +class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable { private[executor] val _bytesWritten = new LongAccumulator private[executor] val _recordsWritten = new LongAccumulator private[executor] val _writeTime = new LongAccumulator @@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable { */ def writeTime: Long = _writeTime.sum - private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) - private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) - private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v) - private[spark] def decBytesWritten(v: Long): Unit = { + private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) + private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) + private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v) + private[spark] override def decBytesWritten(v: Long): Unit = { _bytesWritten.setValue(bytesWritten - v) } - private[spark] def decRecordsWritten(v: Long): Unit = { + private[spark] override def decRecordsWritten(v: Long): Unit = { _recordsWritten.setValue(recordsWritten - v) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index f2cd65fd523ab..5412717d61988 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask( var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager - writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) + writer = manager.getWriter[Any, Any]( + dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) writer.stop(success = true).get } catch { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index df601cbdb2050..18a743fbfa6fc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -38,7 +38,11 @@ private[spark] trait ShuffleManager { dependency: ShuffleDependency[K, V, C]): ShuffleHandle /** Get a writer for a given partition. Called on executors by map tasks. 
*/ - def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] + def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] /** * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 4f8be198e4a72..b51a843a31c31 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -125,7 +125,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V]( handle: ShuffleHandle, mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) val env = SparkEnv.get @@ -138,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager unsafeShuffleHandle, mapId, context, - env.conf) + env.conf, + metrics) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, - context, - env.conf) + env.conf, + metrics) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index edae2f95fce33..1b617297e0a30 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -33,10 +33,9 @@ import scala.util.Random import scala.util.control.NonFatal import com.codahale.metrics.{MetricRegistry, MetricSet} -import com.google.common.io.CountingOutputStream import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.{config, Logging} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source @@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -932,7 +931,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { + writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = { val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a024c83d8d8b7..17390f9c60e79 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -20,9 +20,9 @@ package org.apache.spark.storage import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.FileChannel -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils /** @@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter( syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics, + writeMetrics: ShuffleWriteMetricsReporter, val blockId: BlockId = null) extends OutputStream with Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b159200d79222..eac3db01158d0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -793,8 +793,8 @@ private[spark] class ExternalSorter[K, V, C]( def nextPartition(): Int = cur._1._1 } - logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + - s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling in-memory map to disk " + + s"and it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) forceSpillFiles += spillFile val spillReader = new SpillReader(spillFile) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a07d0e84ea854..30ad3f5575545 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -162,7 +162,8 @@ private UnsafeShuffleWriter createWriter( new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics() ); } @@ -521,7 +522,8 @@ public void testPeakMemoryUsed() throws Exception { new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, - conf); + conf, + taskContext.taskMetrics().shuffleWriteMetrics()); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. 
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 419a26b857ea2..35f728cd57fe2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -362,15 +362,19 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful - val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + val context1 = + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem) + val writer1 = manager.getWriter[Int, Int]( + shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. - val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + val context2 = + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem) + val writer2 = manager.getWriter[Int, Int]( + shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 85ccb33471048..4467c3241a947 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -136,8 +136,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) writer.write(Iterator.empty) writer.stop( /* success = */ true) @@ -160,8 +160,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) writer.write(records) writer.stop( /* success = */ true) @@ -195,8 +195,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) intercept[SparkException] { @@ -217,8 +217,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 333adb0c84025..3fabec0f60125 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -226,7 +226,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"), 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"), + + // [SPARK-26141] Enable custom metrics implementation in shuffle write + // Following are Java private classes + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this") ) // Exclude rules for 2.4.x From 65244b1d790699b6a3a29f2fa111d35f9809111a Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Tue, 27 Nov 2018 20:10:34 +0800 Subject: [PATCH 133/145] [SPARK-23356][SQL][TEST] add new test cases for a + 1,a + b and Rand in SetOperationSuite ## What changes were proposed in this pull request? The purpose of this PR is supplement new test cases for a + 1,a + b and Rand in SetOperationSuite. It comes from the comment of closed PR:#20541, thanks. ## How was this patch tested? add new test cases Closes #23138 from heary-cao/UnionPushTestCases. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../optimizer/SetOperationSuite.scala | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index da3923f8d6477..17e00c9a3ead2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows} +import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanOrEqual, If, Literal, Rand, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -196,4 +196,31 @@ class SetOperationSuite extends PlanTest { )) comparePlans(expectedPlan, rewrittenPlan) } + + test("SPARK-23356 union: expressions with literal in project list are pushed down") { + val unionQuery = testUnion.select(('a + 1).as("aa")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 1).as("aa")) :: + testRelation2.select(('d + 1).as("aa")) :: + testRelation3.select(('g + 1).as("aa")) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: expressions in project list are pushed down") { + val unionQuery = testUnion.select(('a + 'b).as("ab")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 'b).as("ab")) :: + testRelation2.select(('d + 'e).as("ab")) :: + testRelation3.select(('g + 'h).as("ab")) :: Nil).analyze + 
comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: no pushdown for non-deterministic expression") { + val unionQuery = testUnion.select('a, Rand(10).as("rnd")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = unionQuery.analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } } From 2d89d109e19d1e84c4ada3c9d5d48cfcf3d997ea Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 27 Nov 2018 09:09:16 -0800 Subject: [PATCH 134/145] [SPARK-26025][K8S] Speed up docker image build on dev repo. The "build context" for a docker image - basically the whole contents of the current directory where "docker" is invoked - can be huge in a dev build, easily breaking a couple of gigs. Doing that copy 3 times during the build of docker images severely slows down the process. This patch creates a smaller build context - basically mimicking what the make-distribution.sh script does, so that when building the docker images, only the necessary bits are in the current directory. For PySpark and R that is optimized further, since those images are built based on the previously built Spark main image. In my current local clone, the dir size is about 2G, but with this script the "context" sent to docker is about 250M for the main image, 1M for the pyspark image and 8M for the R image. That speeds up the image builds considerably. I also snuck in a fix to the k8s integration test dependencies in the sbt build, so that the examples are properly built (without having to do it manually). Closes #23019 from vanzin/SPARK-26025. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- bin/docker-image-tool.sh | 122 ++++++++++++------ project/SparkBuild.scala | 3 +- .../src/main/dockerfiles/spark/Dockerfile | 14 +- 3 files changed, 91 insertions(+), 48 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index e51201a77cb5d..9f735f1148da4 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -29,6 +29,20 @@ if [ -z "${SPARK_HOME}" ]; then fi . "${SPARK_HOME}/bin/load-spark-env.sh" +CTX_DIR="$SPARK_HOME/target/tmp/docker" + +function is_dev_build { + [ ! -f "$SPARK_HOME/RELEASE" ] +} + +function cleanup_ctx_dir { + if is_dev_build; then + rm -rf "$CTX_DIR" + fi +} + +trap cleanup_ctx_dir EXIT + function image_ref { local image="$1" local add_repo="${2:-1}" @@ -53,80 +67,114 @@ function docker_push { fi } +# Create a smaller build context for docker in dev builds to make the build faster. Docker +# uploads all of the current directory to the daemon, and it can get pretty big with dev +# builds that contain test log files and other artifacts. +# +# Three build contexts are created, one for each image: base, pyspark, and sparkr. For them +# to have the desired effect, the docker command needs to be executed inside the appropriate +# context directory. +# +# Note: docker does not support symlinks in the build context. +function create_dev_build_context {( + set -e + local BASE_CTX="$CTX_DIR/base" + mkdir -p "$BASE_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$BASE_CTX/kubernetes/dockerfiles" + + cp -r "assembly/target/scala-$SPARK_SCALA_VERSION/jars" "$BASE_CTX/jars" + cp -r "resource-managers/kubernetes/integration-tests/tests" \ + "$BASE_CTX/kubernetes/tests" + + mkdir "$BASE_CTX/examples" + cp -r "examples/src" "$BASE_CTX/examples/src" + # Copy just needed examples jars instead of everything. 
+ mkdir "$BASE_CTX/examples/jars" + for i in examples/target/scala-$SPARK_SCALA_VERSION/jars/*; do + if [ ! -f "$BASE_CTX/jars/$(basename $i)" ]; then + cp $i "$BASE_CTX/examples/jars" + fi + done + + for other in bin sbin data; do + cp -r "$other" "$BASE_CTX/$other" + done + + local PYSPARK_CTX="$CTX_DIR/pyspark" + mkdir -p "$PYSPARK_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$PYSPARK_CTX/kubernetes/dockerfiles" + mkdir "$PYSPARK_CTX/python" + cp -r "python/lib" "$PYSPARK_CTX/python/lib" + + local R_CTX="$CTX_DIR/sparkr" + mkdir -p "$R_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$R_CTX/kubernetes/dockerfiles" + cp -r "R" "$R_CTX/R" +)} + +function img_ctx_dir { + if is_dev_build; then + echo "$CTX_DIR/$1" + else + echo "$SPARK_HOME" + fi +} + function build { local BUILD_ARGS - local IMG_PATH - local JARS - - if [ ! -f "$SPARK_HOME/RELEASE" ]; then - # Set image build arguments accordingly if this is a source repo and not a distribution archive. - # - # Note that this will copy all of the example jars directory into the image, and that will - # contain a lot of duplicated jars with the main Spark directory. In a proper distribution, - # the examples directory is cleaned up before generating the distribution tarball, so this - # issue does not occur. - IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles - JARS=assembly/target/scala-$SPARK_SCALA_VERSION/jars - BUILD_ARGS=( - ${BUILD_PARAMS} - --build-arg - img_path=$IMG_PATH - --build-arg - spark_jars=$JARS - --build-arg - example_jars=examples/target/scala-$SPARK_SCALA_VERSION/jars - --build-arg - k8s_tests=resource-managers/kubernetes/integration-tests/tests - ) - else - # Not passed as arguments to docker, but used to validate the Spark directory. - IMG_PATH="kubernetes/dockerfiles" - JARS=jars - BUILD_ARGS=(${BUILD_PARAMS}) + local SPARK_ROOT="$SPARK_HOME" + + if is_dev_build; then + create_dev_build_context || error "Failed to create docker build context." + SPARK_ROOT="$CTX_DIR/base" fi # Verify that the Docker image content directory is present - if [ ! -d "$IMG_PATH" ]; then + if [ ! -d "$SPARK_ROOT/kubernetes/dockerfiles" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi # Verify that Spark has actually been built/is a runnable distribution # i.e. the Spark JARs that the Docker files will place into the image are present - local TOTAL_JARS=$(ls $JARS/spark-* | wc -l) + local TOTAL_JARS=$(ls $SPARK_ROOT/jars/spark-* | wc -l) TOTAL_JARS=$(( $TOTAL_JARS )) if [ "${TOTAL_JARS}" -eq 0 ]; then error "Cannot find Spark JARs. This script assumes that Apache Spark has first been built locally or this is a runnable distribution." fi + local BUILD_ARGS=(${BUILD_PARAMS}) local BINDING_BUILD_ARGS=( ${BUILD_PARAMS} --build-arg base_img=$(image_ref spark) ) - local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local BASEDOCKERFILE=${BASEDOCKERFILE:-"kubernetes/dockerfiles/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-false} local RDOCKERFILE=${RDOCKERFILE:-false} - docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ + (cd $(img_ctx_dir base) && docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$BASEDOCKERFILE" . + -f "$BASEDOCKERFILE" .) if [ $? -ne 0 ]; then error "Failed to build Spark JVM Docker image, please refer to Docker build output for details." 
fi if [ "${PYDOCKERFILE}" != "false" ]; then - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + (cd $(img_ctx_dir pyspark) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ - -f "$PYDOCKERFILE" . + -f "$PYDOCKERFILE" .) if [ $? -ne 0 ]; then error "Failed to build PySpark Docker image, please refer to Docker build output for details." fi fi if [ "${RDOCKERFILE}" != "false" ]; then - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + (cd $(img_ctx_dir sparkr) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-r) \ - -f "$RDOCKERFILE" . + -f "$RDOCKERFILE" .) if [ $? -ne 0 ]; then error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 08e22fab65165..bb834bc483f1f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -516,7 +516,8 @@ object KubernetesIntegrationTests { s"-Dspark.kubernetes.test.unpackSparkDir=$sparkHome" ), // Force packaging before building images, so that the latest code is tested. - dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly).value + dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly) + .dependsOn(packageBin in Compile in examples).value ) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 5f469c30a96fa..89b20e1446229 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,11 +17,6 @@ FROM openjdk:8-alpine -ARG spark_jars=jars -ARG example_jars=examples/jars -ARG img_path=kubernetes/dockerfiles -ARG k8s_tests=kubernetes/tests - # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -41,13 +36,12 @@ RUN set -ex && \ echo "auth required pam_wheel.so use_uid" >> /etc/pam.d/su && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY ${spark_jars} /opt/spark/jars +COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY ${img_path}/spark/entrypoint.sh /opt/ -COPY ${example_jars} /opt/spark/examples/jars -COPY examples/src /opt/spark/examples/src -COPY ${k8s_tests} /opt/spark/tests +COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark From 8c6871828e3eb9fdb3bc665441a1aaf60b86b1e7 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 28 Nov 2018 13:37:11 +0800 Subject: [PATCH 135/145] [SPARK-26159] Codegen for LocalTableScanExec and RDDScanExec ## What changes were proposed in this pull request? Implement codegen for `LocalTableScanExec` and `ExistingRDDExec`. Refactor to share code between `LocalTableScanExec`, `ExistingRDDExec`, `InputAdapter` and `RowDataSourceScanExec`. The difference in `doProduce` between these four was that `ExistingRDDExec` and `RowDataSourceScanExec` triggered adding an `UnsafeProjection`, while `InputAdapter` and `LocalTableScanExec` did not. In the new trait `InputRDDCodegen` I added a flag `createUnsafeProjection` which the operators set accordingly. 
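For reference, the shared piece boils down to this sketch (condensed from the new trait in `WholeStageCodegenExec.scala` further down; the shared `doProduce` body is elided):

```scala
// Leaf codegen node reading from a single RDD. The shared doProduce() consumes
// inputRDD row by row; createUnsafeProjection controls whether the consumed rows
// are first turned into UnsafeRows via an UnsafeProjection.
trait InputRDDCodegen extends CodegenSupport {
  def inputRDD: RDD[InternalRow]
  protected val createUnsafeProjection: Boolean
  // ... shared doProduce() implementation ...
}
```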
Note: `LocalTableScanExec` explicitly creates its input as `UnsafeRows`, so it was obvious why it doesn't need an `UnsafeProjection`. But if an `InputAdapter` may take input that is `InternalRows` but not `UnsafeRows`, then I think it doesn't need an unsafe projection just because any other operator that is its parent would do that. That assumes that that any parent operator would always result in some `UnsafeProjection` being eventually added, and hence the output of the `WholeStageCodegen` unit would be `UnsafeRows`. If these assumptions hold, I think `createUnsafeProjection` could be set to `(parent == null)`. Note: Do not codegen `LocalTableScanExec` when it's the only operator. `LocalTableScanExec` has optimized driver-only `executeCollect` and `executeTake` code paths that are used to return `Command` results without starting Spark Jobs. They can no longer be used if the `LocalTableScanExec` gets optimized. ## How was this patch tested? Covered and used in existing tests. Closes #23127 from juliuszsompolski/SPARK-26159. Authored-by: Juliusz Sompolski Signed-off-by: Wenchen Fan --- python/pyspark/sql/dataframe.py | 2 +- .../sql/execution/DataSourceScanExec.scala | 28 +------ .../spark/sql/execution/ExistingRDD.scala | 7 +- .../sql/execution/LocalTableScanExec.scala | 10 ++- .../sql/execution/WholeStageCodegenExec.scala | 78 ++++++++++++++----- .../sql-tests/results/operators.sql.out | 12 +-- 6 files changed, 86 insertions(+), 51 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ca15b36699166..b8833a39078ba 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -257,7 +257,7 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan ExistingRDD[age#0,name#1] + *(1) Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 77e381ef6e6b4..4faa27c2c1e23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -84,7 +84,7 @@ case class RowDataSourceScanExec( rdd: RDD[InternalRow], @transient relation: BaseRelation, override val tableIdentifier: Option[TableIdentifier]) - extends DataSourceScanExec { + extends DataSourceScanExec with InputRDDCodegen { def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput) @@ -104,30 +104,10 @@ case class RowDataSourceScanExec( } } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - rdd :: Nil - } + // Input can be InternalRow, has to be turned into UnsafeRows. 
+ override protected val createUnsafeProjection: Boolean = true - override protected def doProduce(ctx: CodegenContext): String = { - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val exprRows = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable) - } - val row = ctx.freshName("row") - ctx.INPUT_ROW = row - ctx.currentVars = null - val columnsRowInput = exprRows.map(_.genCode(ctx)) - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } + override def inputRDD: RDD[InternalRow] = rdd override val metadata: Map[String, String] = { val markedFilters = for (filter <- filters) yield { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 9f67d556af362..e214bfd050410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -175,7 +175,7 @@ case class RDDScanExec( rdd: RDD[InternalRow], name: String, override val outputPartitioning: Partitioning = UnknownPartitioning(0), - override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { + override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode with InputRDDCodegen { private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") @@ -199,4 +199,9 @@ case class RDDScanExec( override def simpleString: String = { s"$nodeName${truncatedString(output, "[", ",", "]")}" } + + // Input can be InternalRow, has to be turned into UnsafeRows. + override protected val createUnsafeProjection: Boolean = true + + override def inputRDD: RDD[InternalRow] = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 448eb703eacde..31640db3722ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class LocalTableScanExec( output: Seq[Attribute], - @transient rows: Seq[InternalRow]) extends LeafExecNode { + @transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -76,4 +76,12 @@ case class LocalTableScanExec( longMetric("numOutputRows").add(taken.size) taken } + + // Input is already UnsafeRows. + override protected val createUnsafeProjection: Boolean = false + + // Do not codegen when there is no parent - to support the fast driver-local collect/take paths. 
+ override def supportCodegen: Boolean = (parent != null) + + override def inputRDD: RDD[InternalRow] = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 29bcbcae366c5..fbda0d87a175f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -350,6 +350,15 @@ trait CodegenSupport extends SparkPlan { */ def needStopCheck: Boolean = parent.needStopCheck + /** + * Helper default should stop check code. + */ + def shouldStopCheckCode: String = if (needStopCheck) { + "if (shouldStop()) return;" + } else { + "// shouldStop check is eliminated" + } + /** * A sequence of checks which evaluate to true if the downstream Limit operators have not received * enough records and reached the limit. If current node is a data producing node, it can leverage @@ -406,6 +415,53 @@ trait BlockingOperatorWithCodegen extends CodegenSupport { override def limitNotReachedChecks: Seq[String] = Nil } +/** + * Leaf codegen node reading from a single RDD. + */ +trait InputRDDCodegen extends CodegenSupport { + + def inputRDD: RDD[InternalRow] + + // If the input can be InternalRows, an UnsafeProjection needs to be created. + protected val createUnsafeProjection: Boolean + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + override def doProduce(ctx: CodegenContext): String = { + // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", + forceInline = true) + val row = ctx.freshName("row") + + val outputVars = if (createUnsafeProjection) { + // creating the vars will make the parent consume add an unsafe projection. + ctx.INPUT_ROW = row + ctx.currentVars = null + output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + } else { + null + } + + val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) { + val numOutputRows = metricTerm(ctx, "numOutputRows") + s"$numOutputRows.add(1);" + } else { + "" + } + s""" + | while ($limitNotReachedCond $input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | ${updateNumOutputRowsMetrics} + | ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim} + | ${shouldStopCheckCode} + | } + """.stripMargin + } +} /** * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen. @@ -413,7 +469,7 @@ trait BlockingOperatorWithCodegen extends CodegenSupport { * This is the leaf node of a tree with WholeStageCodegen that is used to generate code * that consumes an RDD iterator of InternalRow. */ -case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport { +case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen { override def output: Seq[Attribute] = child.output @@ -429,24 +485,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp child.doExecuteBroadcast() } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - child.execute() :: Nil - } + override def inputRDD: RDD[InternalRow] = child.execute() - override def doProduce(ctx: CodegenContext): String = { - // Right now, InputAdapter is only used when there is one input RDD. 
- // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - forceInline = true) - val row = ctx.freshName("row") - s""" - | while ($limitNotReachedCond $input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | ${consume(ctx, null, row).trim} - | if (shouldStop()) return; - | } - """.stripMargin - } + // InputAdapter does not need UnsafeProjection. + protected val createUnsafeProjection: Boolean = false override def generateTreeString( depth: Int, diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fd1d0db9e3f78..570b281353f3d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -201,7 +201,7 @@ struct -- !query 24 output == Physical Plan == *Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 25 @@ -211,7 +211,7 @@ struct -- !query 25 output == Physical Plan == *Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 26 @@ -221,7 +221,7 @@ struct -- !query 26 output == Physical Plan == *Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 27 @@ -231,7 +231,7 @@ struct -- !query 27 output == Physical Plan == *Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 28 @@ -241,7 +241,7 @@ struct -- !query 28 output == Physical Plan == *Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 29 @@ -251,7 +251,7 @@ struct -- !query 29 output == Physical Plan == *Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 30 From 09a91d98bdecb86ecad4647b7ef5fb3f69bdc671 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Nov 2018 16:21:42 +0800 Subject: [PATCH 136/145] [SPARK-26021][SQL][FOLLOWUP] add test for special floating point values ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/23043 . Add a test to show the minor behavior change introduced by #23043 , and add migration guide. ## How was this patch tested? a new test Closes #23141 from cloud-fan/follow. 
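To make the behaviour change concrete, here is a small standalone illustration (not part of the patch): 0.0 and -0.0 compare equal as doubles but have different bit patterns, which is exactly what the updated `QueryTest.compare` below relies on to tell them apart:

```scala
// -0.0 == 0.0 numerically, yet the raw bits differ in the sign bit; binary
// comparisons on unsafe data would therefore treat them as different values unless
// -0.0 is normalized to 0.0 when it is written.
object MinusZeroBits {
  def main(args: Array[String]): Unit = {
    println(0.0d == -0.0d)                               // true
    println(java.lang.Double.doubleToRawLongBits(0.0d))  // 0
    println(java.lang.Double.doubleToRawLongBits(-0.0d)) // -9223372036854775808 (sign bit set)
  }
}
```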
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../spark/unsafe/PlatformUtilSuite.java | 12 +++++--- docs/sql-migration-guide-upgrade.md | 6 ++-- .../catalyst/expressions/UnsafeArrayData.java | 6 ---- .../spark/sql/DatasetPrimitiveSuite.scala | 29 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 7 +++++ 5 files changed, 48 insertions(+), 12 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index ab34324eb54cc..2474081dad5c9 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -165,10 +165,14 @@ public void writeMinusZeroIsReplacedWithZero() { byte[] floatBytes = new byte[Float.BYTES]; Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); - double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); - float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); - Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); - Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); + byte[] doubleBytes2 = new byte[Double.BYTES]; + byte[] floatBytes2 = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f); + + // Make sure the bytes we write from 0.0 and -0.0 are same. + Assert.assertArrayEquals(doubleBytes, doubleBytes2); + Assert.assertArrayEquals(floatBytes, floatBytes2); } } diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 68cb8f5a0d18c..25cd541190919 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,14 +17,16 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. - - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. + - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. 
- In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. - + - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`. + - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 9002abdcfd474..d5f679fe23d48 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -334,17 +334,11 @@ public void setLong(int ordinal, long value) { } public void setFloat(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value); } public void setDouble(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 96a6792f52f3e..0ded5d8ce1e28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -393,4 +393,33 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { val ds = spark.createDataset(data) checkDataset(ds, data: _*) } + + test("special floating point values") { + import org.scalatest.exceptions.TestFailedException + + // Spark treats -0.0 as 0.0 + intercept[TestFailedException] { + checkDataset(Seq(-0.0d).toDS(), -0.0d) + } + intercept[TestFailedException] { + checkDataset(Seq(-0.0f).toDS(), -0.0f) + } + intercept[TestFailedException] { + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(-0.0)) + } + + val floats = Seq[Float](-0.0f, 0.0f, Float.NaN).toDS() + 
checkDataset(floats, 0.0f, 0.0f, Float.NaN) + + val doubles = Seq[Double](-0.0d, 0.0d, Double.NaN).toDS() + checkDataset(doubles, 0.0, 0.0, Double.NaN) + + checkDataset(Seq(Tuple1(Float.NaN)).toDS(), Tuple1(Float.NaN)) + checkDataset(Seq(Tuple1(-0.0f)).toDS(), Tuple1(0.0f)) + checkDataset(Seq(Tuple1(Double.NaN)).toDS(), Tuple1(Double.NaN)) + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(0.0)) + + val complex = Map(Array(Seq(Tuple1(Double.NaN))) -> Map(Tuple2(Float.NaN, null))) + checkDataset(Seq(complex).toDS(), complex) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8ba67239fb907..a547676c5ed5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -132,6 +132,13 @@ abstract class QueryTest extends PlanTest { a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Iterable[_], b: Iterable[_]) => a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + // 0.0 == -0.0, turn float/double to binary before comparison, to distinguish 0.0 and -0.0. + case (a: Double, b: Double) => + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + case (a: Float, b: Float) => + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) case (a, b) => a == b } From 93112e693082f3fba24cebaf9a98dcf5c1eb84af Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 28 Nov 2018 20:18:13 +0800 Subject: [PATCH 137/145] [SPARK-26142][SQL] Implement shuffle read metrics in SQL ## What changes were proposed in this pull request? Implement `SQLShuffleMetricsReporter` on the sql side as the customized ShuffleMetricsReporter, which extended the `TempShuffleReadMetrics` and update SQLMetrics, in this way shuffle metrics can be reported in the SQL UI. ## How was this patch tested? Add UT in SQLMetricsSuite. Manual test locally, before: ![image](https://user-images.githubusercontent.com/4833765/48960517-30f97880-efa8-11e8-982c-92d05938fd1d.png) after: ![image](https://user-images.githubusercontent.com/4833765/48960587-b54bfb80-efa8-11e8-8e95-7a3c8c74cc5c.png) Closes #23128 from xuanyuanking/SPARK-26142. 
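The reporting itself is a simple fan-out: every shuffle read increment goes both to the SQL metric shown in the UI and to the task-level `TempShuffleReadMetrics`. A minimal, self-contained sketch of that pattern with placeholder names (the real `SQLShuffleMetricsReporter`, which does this for all seven read metrics, is in the diff below):

```scala
// Each increment is applied to both sinks so the SQL UI metrics and the basic task
// shuffle metrics stay in sync.
class FanOutReadMetrics(sqlMetric: Long => Unit, taskMetric: Long => Unit) {
  def incRecordsRead(v: Long): Unit = {
    sqlMetric(v)
    taskMetric(v)
  }
}

object FanOutReadMetricsDemo {
  def main(args: Array[String]): Unit = {
    var sqlRecords = 0L
    var taskRecords = 0L
    val reporter = new FanOutReadMetrics(v => sqlRecords += v, v => taskRecords += v)
    reporter.incRecordsRead(4L)
    println((sqlRecords, taskRecords)) // (4,4)
  }
}
```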
Lead-authored-by: Yuanjian Li Co-authored-by: liyuanjian Signed-off-by: Wenchen Fan --- .../spark/sql/execution/ShuffledRowRDD.scala | 9 ++- .../exchange/ShuffleExchangeExec.scala | 5 +- .../apache/spark/sql/execution/limit.scala | 10 ++- .../sql/execution/metric/SQLMetrics.scala | 20 ++++++ .../metric/SQLShuffleMetricsReporter.scala | 67 +++++++++++++++++++ .../execution/UnsafeRowSerializerSuite.scala | 5 +- .../execution/metric/SQLMetricsSuite.scala | 21 ++++-- 7 files changed, 126 insertions(+), 11 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 542266bc1ae07..9b05faaed0459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -22,6 +22,7 @@ import java.util.Arrays import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleMetricsReporter} /** * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition @@ -112,6 +113,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A */ class ShuffledRowRDD( var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + metrics: Map[String, SQLMetric], specifiedPartitionStartIndices: Option[Array[Int]] = None) extends RDD[InternalRow](dependency.rdd.context, Nil) { @@ -154,7 +156,10 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] - val metrics = context.taskMetrics().createTempShuffleReadMetrics() + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. + val sqlMetricsReporter = new SQLShuffleMetricsReporter(tempMetrics, metrics) // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -163,7 +168,7 @@ class ShuffledRowRDD( shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, context, - metrics) + sqlMetricsReporter) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index d6742ab3e0f31..8938d93da90eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -48,7 +48,8 @@ case class ShuffleExchangeExec( // e.g. 
it can be null on the Executor side override lazy val metrics = Map( - "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")) + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") + ) ++ SQLMetrics.getShuffleReadMetrics(sparkContext) override def nodeName: String = { val extraInfo = coordinator match { @@ -108,7 +109,7 @@ case class ShuffleExchangeExec( assert(newPartitioning.isInstanceOf[HashPartitioning]) newPartitioning = UnknownPartitioning(indices.length) } - new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 90dafcf535914..ea845da8438fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Take the first `limit` elements and collect them to a single partition. @@ -37,11 +38,13 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode override def outputPartitioning: Partitioning = SinglePartition override def executeCollect(): Array[InternalRow] = child.executeTake(limit) private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - locallyLimited, child.output, SinglePartition, serializer)) + locallyLimited, child.output, SinglePartition, serializer), + metrics) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -151,6 +154,8 @@ case class TakeOrderedAndProjectExec( private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) + protected override def doExecute(): RDD[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val localTopK: RDD[InternalRow] = { @@ -160,7 +165,8 @@ case class TakeOrderedAndProjectExec( } val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - localTopK, child.output, SinglePartition, serializer)) + localTopK, child.output, SinglePartition, serializer), + metrics) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) if (projectList != child.output) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index cbf707f4a9cfd..0b5ee3a5e0577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -82,6 +82,14 @@ object SQLMetrics { private val baseForAvgMetric: Int = 10 + val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched" + val LOCAL_BLOCKS_FETCHED = 
"localBlocksFetched" + val REMOTE_BYTES_READ = "remoteBytesRead" + val REMOTE_BYTES_READ_TO_DISK = "remoteBytesReadToDisk" + val LOCAL_BYTES_READ = "localBytesRead" + val FETCH_WAIT_TIME = "fetchWaitTime" + val RECORDS_READ = "recordsRead" + /** * Converts a double value to long value by multiplying a base integer, so we can store it in * `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore @@ -194,4 +202,16 @@ object SQLMetrics { SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) } } + + /** + * Create all shuffle read relative metrics and return the Map. + */ + def getShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( + REMOTE_BLOCKS_FETCHED -> createMetric(sc, "remote blocks fetched"), + LOCAL_BLOCKS_FETCHED -> createMetric(sc, "local blocks fetched"), + REMOTE_BYTES_READ -> createSizeMetric(sc, "remote bytes read"), + REMOTE_BYTES_READ_TO_DISK -> createSizeMetric(sc, "remote bytes read to disk"), + LOCAL_BYTES_READ -> createSizeMetric(sc, "local bytes read"), + FETCH_WAIT_TIME -> createTimingMetric(sc, "fetch wait time"), + RECORDS_READ -> createMetric(sc, "records read")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala new file mode 100644 index 0000000000000..542141ea4b4e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.executor.TempShuffleReadMetrics + +/** + * A shuffle metrics reporter for SQL exchange operators. + * @param tempMetrics [[TempShuffleReadMetrics]] created in TaskContext. + * @param metrics All metrics in current SparkPlan. This param should not empty and + * contains all shuffle metrics defined in [[SQLMetrics.getShuffleReadMetrics]]. 
+ */ +private[spark] class SQLShuffleMetricsReporter( + tempMetrics: TempShuffleReadMetrics, + metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics { + private[this] val _remoteBlocksFetched = metrics(SQLMetrics.REMOTE_BLOCKS_FETCHED) + private[this] val _localBlocksFetched = metrics(SQLMetrics.LOCAL_BLOCKS_FETCHED) + private[this] val _remoteBytesRead = metrics(SQLMetrics.REMOTE_BYTES_READ) + private[this] val _remoteBytesReadToDisk = metrics(SQLMetrics.REMOTE_BYTES_READ_TO_DISK) + private[this] val _localBytesRead = metrics(SQLMetrics.LOCAL_BYTES_READ) + private[this] val _fetchWaitTime = metrics(SQLMetrics.FETCH_WAIT_TIME) + private[this] val _recordsRead = metrics(SQLMetrics.RECORDS_READ) + + override def incRemoteBlocksFetched(v: Long): Unit = { + _remoteBlocksFetched.add(v) + tempMetrics.incRemoteBlocksFetched(v) + } + override def incLocalBlocksFetched(v: Long): Unit = { + _localBlocksFetched.add(v) + tempMetrics.incLocalBlocksFetched(v) + } + override def incRemoteBytesRead(v: Long): Unit = { + _remoteBytesRead.add(v) + tempMetrics.incRemoteBytesRead(v) + } + override def incRemoteBytesReadToDisk(v: Long): Unit = { + _remoteBytesReadToDisk.add(v) + tempMetrics.incRemoteBytesReadToDisk(v) + } + override def incLocalBytesRead(v: Long): Unit = { + _localBytesRead.add(v) + tempMetrics.incLocalBytesRead(v) + } + override def incFetchWaitTime(v: Long): Unit = { + _fetchWaitTime.add(v) + tempMetrics.incFetchWaitTime(v) + } + override def incRecordsRead(v: Long): Unit = { + _recordsRead.add(v) + tempMetrics.incRecordsRead(v) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index d305ce3e698ae..96b3aa5ee75b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -137,7 +138,9 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { rowsRDD, new PartitionIdPassthrough(2), new UnsafeRowSerializer(2)) - val shuffled = new ShuffledRowRDD(dependency) + val shuffled = new ShuffledRowRDD( + dependency, + SQLMetrics.getShuffleReadMetrics(spark.sparkContext)) shuffled.count() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index b955c157a620e..0f1d08b6af5d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -94,8 +94,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 1L, "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) + val shuffleExpected1 = Map( + "records read" -> 2L, + "local blocks fetched" -> 2L, + "remote blocks fetched" -> 0L) 
testSparkPlanMetrics(df, 1, Map( 2L -> (("HashAggregate", expected1(0))), + 1L -> (("Exchange", shuffleExpected1)), 0L -> (("HashAggregate", expected1(1)))) ) @@ -106,8 +111,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 3L, "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) + val shuffleExpected2 = Map( + "records read" -> 4L, + "local blocks fetched" -> 4L, + "remote blocks fetched" -> 0L) testSparkPlanMetrics(df2, 1, Map( 2L -> (("HashAggregate", expected2(0))), + 1L -> (("Exchange", shuffleExpected2)), 0L -> (("HashAggregate", expected2(1)))) ) } @@ -191,7 +201,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared testSparkPlanMetrics(df, 1, Map( 0L -> (("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of output rows" -> 4L)))) + "number of output rows" -> 4L))), + 2L -> (("Exchange", Map( + "records read" -> 4L, + "local blocks fetched" -> 2L, + "remote blocks fetched" -> 0L)))) ) } } @@ -208,7 +222,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> (("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + // It's 8 because we read 6 rows in the left and 2 row in the right one "number of output rows" -> 8L)))) ) @@ -216,7 +230,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( 0L -> (("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + // It's 8 because we read 6 rows in the left and 2 row in the right one "number of output rows" -> 8L)))) ) } @@ -287,7 +301,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Assume the execution plan is // ... -> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) val df = df1.join(df2, "key") - val metrics = getSparkPlanMetrics(df, 1, Set(1L)) testSparkPlanMetrics(df, 1, Map( 1L -> (("ShuffledHashJoin", Map( "number of output rows" -> 2L, From 438f8fd675d8f819373b6643dea3a77d954b6822 Mon Sep 17 00:00:00 2001 From: Sergey Zhemzhitsky Date: Wed, 28 Nov 2018 20:22:24 +0800 Subject: [PATCH 138/145] [SPARK-26114][CORE] ExternalSorter's readingIterator field leak ## What changes were proposed in this pull request? This pull request fixes [SPARK-26114](https://issues.apache.org/jira/browse/SPARK-26114) issue that occurs when trying to reduce the number of partitions by means of coalesce without shuffling after shuffle-based transformations. The leak occurs because of not cleaning up `ExternalSorter`'s `readingIterator` field as it's done for its `map` and `buffer` fields. Additionally there are changes to the `CompletionIterator` to prevent capturing its `sub`-iterator and holding it even after the completion iterator completes. It is necessary because in some cases, e.g. in case of standard scala's `flatMap` iterator (which is used is `CoalescedRDD`'s `compute` method) the next value of the main iterator is assigned to `flatMap`'s `cur` field only after it is available. 
For DAGs where ShuffledRDD is a parent of CoalescedRDD it means that the data should be fetched from the map-side of the shuffle, but the process of fetching this data consumes quite a lot of memory in addition to the memory already consumed by the iterator held by `flatMap`'s `cur` field (until it is reassigned). For the following data ```scala import org.apache.hadoop.io._ import org.apache.hadoop.io.compress._ import org.apache.commons.lang._ import org.apache.spark._ // generate 100M records of sample data sc.makeRDD(1 to 1000, 1000) .flatMap(item => (1 to 100000) .map(i => new Text(RandomStringUtils.randomAlphanumeric(3).toLowerCase) -> new Text(RandomStringUtils.randomAlphanumeric(1024)))) .saveAsSequenceFile("/tmp/random-strings", Some(classOf[GzipCodec])) ``` and the following job ```scala import org.apache.hadoop.io._ import org.apache.spark._ import org.apache.spark.storage._ val rdd = sc.sequenceFile("/tmp/random-strings", classOf[Text], classOf[Text]) rdd .map(item => item._1.toString -> item._2.toString) .repartitionAndSortWithinPartitions(new HashPartitioner(1000)) .coalesce(10,false) .count ``` ... executed like the following ```bash spark-shell \ --num-executors=5 \ --executor-cores=2 \ --master=yarn \ --deploy-mode=client \ --conf spark.executor.memoryOverhead=512 \ --conf spark.executor.memory=1g \ --conf spark.dynamicAllocation.enabled=false \ --conf spark.executor.extraJavaOptions='-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/tmp -Dio.netty.noUnsafe=true' ``` ... executors are always failing with OutOfMemoryErrors. The main issue is multiple leaks of ExternalSorter references. For example, in case of 2 tasks per executor it is expected to be 2 simultaneous instances of ExternalSorter per executor but heap dump generated on OutOfMemoryError shows that there are more ones. ![run1-noparams-dominator-tree-externalsorter](https://user-images.githubusercontent.com/1523889/48703665-782ce580-ec05-11e8-95a9-d6c94e8285ab.png) P.S. This PR does not cover cases with CoGroupedRDDs which use ExternalAppendOnlyMap internally, which itself can lead to OutOfMemoryErrors in many places. ## How was this patch tested? - Existing unit tests - New unit tests - Job executions on the live environment Here is the screenshot before applying this patch ![run3-noparams-failure-ui-5x2-repartition-and-sort](https://user-images.githubusercontent.com/1523889/48700395-f769eb80-ebfc-11e8-831b-e94c757d416c.png) Here is the screenshot after applying this patch ![run3-noparams-success-ui-5x2-repartition-and-sort](https://user-images.githubusercontent.com/1523889/48700610-7a8b4180-ebfd-11e8-9761-baaf38a58e66.png) And in case of reducing the number of executors even more the job is still stable ![run3-noparams-success-ui-2x2-repartition-and-sort](https://user-images.githubusercontent.com/1523889/48700619-82e37c80-ebfd-11e8-98ed-a38e1f1f1fd9.png) Closes #23083 from szhem/SPARK-26114-externalsorter-leak. 
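The `CompletionIterator` part of the fix boils down to dropping the reference to the exhausted sub-iterator so it can be garbage collected even while the wrapper itself is still reachable. A simplified, self-contained sketch of that idea (the actual change to `CompletionIterator`, plus the `readingIterator = null` cleanup in `ExternalSorter`, follows in the diff):

```scala
// Once the wrapped iterator is exhausted, replace it with Iterator.empty so the
// potentially memory-heavy underlying iterator becomes collectable even if something
// (e.g. flatMap's internal `cur` field) still holds this wrapper.
class ReleasingCompletionIterator[A](sub: Iterator[A], completion: () => Unit)
  extends Iterator[A] {

  private var iter: Iterator[A] = sub
  private var completed = false

  override def next(): A = iter.next()

  override def hasNext: Boolean = {
    val more = iter.hasNext
    if (!more && !completed) {
      completed = true
      iter = Iterator.empty   // release the consumed iterator's resources early
      completion()
    }
    more
  }
}
```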
Authored-by: Sergey Zhemzhitsky Signed-off-by: Wenchen Fan --- .../spark/util/CompletionIterator.scala | 7 ++++-- .../util/collection/ExternalSorter.scala | 3 ++- .../spark/util/CompletionIteratorSuite.scala | 22 +++++++++++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 21acaa95c5645..f4d6c7a28d2e4 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -25,11 +25,14 @@ private[spark] abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { private[this] var completed = false - def next(): A = sub.next() + private[this] var iter = sub + def next(): A = iter.next() def hasNext: Boolean = { - val r = sub.hasNext + val r = iter.hasNext if (!r && !completed) { completed = true + // reassign to release resources of highly resource consuming iterators early + iter = Iterator.empty.asInstanceOf[I] completion() } r diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index eac3db01158d0..46279e79d78db 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -727,9 +727,10 @@ private[spark] class ExternalSorter[K, V, C]( spills.clear() forceSpillFiles.foreach(s => s.file.delete()) forceSpillFiles.clear() - if (map != null || buffer != null) { + if (map != null || buffer != null || readingIterator != null) { map = null // So that the memory can be garbage-collected buffer = null // So that the memory can be garbage-collected + readingIterator = null // So that the memory can be garbage-collected releaseMemory() } } diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala index 688fcd9f9aaba..29421f7aa9e36 100644 --- a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.util +import java.lang.ref.PhantomReference +import java.lang.ref.ReferenceQueue + import org.apache.spark.SparkFunSuite class CompletionIteratorSuite extends SparkFunSuite { @@ -44,4 +47,23 @@ class CompletionIteratorSuite extends SparkFunSuite { assert(!completionIter.hasNext) assert(numTimesCompleted === 1) } + test("reference to sub iterator should not be available after completion") { + var sub = Iterator(1, 2, 3) + + val refQueue = new ReferenceQueue[Iterator[Int]] + val ref = new PhantomReference[Iterator[Int]](sub, refQueue) + + val iter = CompletionIterator[Int, Iterator[Int]](sub, {}) + sub = null + iter.toArray + + for (_ <- 1 to 100 if !ref.isEnqueued) { + System.gc() + if (!ref.isEnqueued) { + Thread.sleep(10) + } + } + assert(ref.isEnqueued) + assert(refQueue.poll() === ref) + } } From affe80958d366f399466a9dba8e03da7f3b7b9bf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Nov 2018 20:38:42 +0800 Subject: [PATCH 139/145] [SPARK-26147][SQL] only pull out unevaluable python udf from join condition ## What changes were proposed in this pull request? 
https://github.com/apache/spark/pull/22326 made a mistake that, not all python UDFs are unevaluable in join condition. Only python UDFs that refer to attributes from both join side are unevaluable. This PR fixes this mistake. ## How was this patch tested? a new test Closes #23153 from cloud-fan/join. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests/test_udf.py | 12 ++ .../spark/sql/catalyst/optimizer/joins.scala | 22 ++-- ...PullOutPythonUDFInJoinConditionSuite.scala | 120 ++++++++++++------ 3 files changed, 106 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index d2dfb52f54475..ed298f724d551 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -209,6 +209,18 @@ def test_udf_in_join_condition(self): with self.sql_conf({"spark.sql.crossJoin.enabled": True}): self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_left_outer_join_condition(self): + # regression test for SPARK-26147 + from pyspark.sql.functions import udf, col + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a: str(a), StringType()) + # The join condition can't be pushed down, as it refers to attributes from both sides. + # The Python UDF only refer to attributes from one side, so it's evaluable. + df = left.join(right, f("a") == col("b").cast("string"), how="left_outer") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_left_semi_join_condition(self): # regression test for SPARK-25314 from pyspark.sql.functions import udf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 7149edee0173e..6ebb194d71c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF - * and pull them out from join condition. For python udf accessing attributes from only one side, - * they are pushed down by operation push down rules. If not (e.g. user disables filter push - * down rules), we need to pull them out in this rule too. + * PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides. + * See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them + * out from join condition. 
*/ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper { - def hasPythonUDF(expression: Expression): Boolean = { - expression.collectFirst { case udf: PythonUDF => udf }.isDefined + + private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = { + expr.find { e => + PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right) + }.isDefined } override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case j @ Join(_, _, joinType, condition) - if condition.isDefined && hasPythonUDF(condition.get) => + case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) => if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { // The current strategy only support InnerLike and LeftSemi join because for other type, // it breaks SQL semantic if we run the join condition as a filter after join. If we pass @@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH } // If condition expression contains python udf, it will be moved out from // the new join conditions. - val (udf, rest) = - splitConjunctivePredicates(condition.get).partition(hasPythonUDF) + val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j)) val newCondition = if (rest.isEmpty) { - logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," + + logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," + s" it will be moved out and the join plan will be turned to cross join.") None } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala index d3867f2b6bd0e..3f1c91df7f2e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.scalatest.Matchers._ - import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf._ -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, IntegerType} class PullOutPythonUDFInJoinConditionSuite extends PlanTest { @@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { CheckCartesianProducts) :: Nil } - val testRelationLeft = LocalRelation('a.int, 'b.int) - val testRelationRight = LocalRelation('c.int, 'd.int) + val attrA = 'a.int + val attrB = 'b.int + val attrC = 'c.int + val attrD = 'd.int + + val testRelationLeft = LocalRelation(attrA, attrB) + val testRelationRight = LocalRelation(attrC, attrD) + + // This join condition refers to attributes from 2 tables, but the PythonUDF inside it only + // refer to attributes from one side. 
+ val evaluableJoinCond = { + val pythonUDF = PythonUDF("evaluable", null, + IntegerType, + Seq(attrA), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + pythonUDF === attrC + } - // Dummy python UDF for testing. Unable to execute. - val pythonUDF = PythonUDF("pythonUDF", null, + // This join condition is a PythonUDF which refers to attributes from 2 tables. + val unevaluableJoinCond = PythonUDF("unevaluable", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) @@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { } } - test("inner join condition with python udf only") { - val query = testRelationLeft.join( + test("inner join condition with python udf") { + val query1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF)) - val expected = testRelationLeft.join( + condition = Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF).analyze - comparePlanWithCrossJoinEnable(query, expected) + condition = None).where(unevaluableJoinCond).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } - test("left semi join condition with python udf only") { - val query = testRelationLeft.join( + test("left semi join condition with python udf") { + val query1 = testRelationLeft.join( testRelationRight, joinType = LeftSemi, - condition = Some(pythonUDF)) - val expected = testRelationLeft.join( + condition = Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF).select('a, 'b).analyze - comparePlanWithCrossJoinEnable(query, expected) + condition = None).where(unevaluableJoinCond).select('a, 'b).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } - test("python udf and common condition") { + test("unevaluable python udf and common condition") { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF && 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } - test("python udf or common condition") { + test("unevaluable python udf or common condition") { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF || 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze + condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze comparePlanWithCrossJoinEnable(query, expected) } - test("pull out whole complex condition with multiple python udf") { + test("pull out 
whole complex condition with multiple unevaluable python udf") { val pythonUDF1 = PythonUDF("pythonUDF1", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1 + val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1 val query = testRelationLeft.join( testRelationRight, @@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { comparePlanWithCrossJoinEnable(query, expected) } - test("partial pull out complex condition with multiple python udf") { + test("partial pull out complex condition with multiple unevaluable python udf") { val pythonUDF1 = PythonUDF("pythonUDF1", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr + val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr val query = testRelationLeft.join( testRelationRight, @@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("pull out unevaluable python udf when it's mixed with evaluable one") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond && unevaluableJoinCond)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } test("throw an exception for not support join type") { for (joinType <- unsupportedJoinTypes) { - val thrownException = the [AnalysisException] thrownBy { + val e = intercept[AnalysisException] { val query = testRelationLeft.join( testRelationRight, joinType, - condition = Some(pythonUDF)) + condition = Some(unevaluableJoinCond)) Optimize.execute(query.analyze) } - assert(thrownException.message.contentEquals( + assert(e.message.contentEquals( s"Using PythonUDF in join condition of join type $joinType is not supported.")) + + val query2 = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } } } - From ce61bac1d84f8577b180400e44bd9bf22292e0b6 Mon Sep 17 00:00:00 2001 From: Mark Pavey Date: Wed, 28 Nov 2018 07:19:47 -0800 Subject: [PATCH 140/145] =?UTF-8?q?[SPARK-26137][CORE]=20Use=20Java=20syst?= =?UTF-8?q?em=20property=20"file.separator"=20inste=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … of hard coded "/" in DependencyUtils ## What changes were proposed in this pull request? Use Java system property "file.separator" instead of hard coded "/" in DependencyUtils. ## How was this patch tested? Manual test: Submit Spark application via REST API that reads data from Elasticsearch using spark-elasticsearch library. 
Without fix application fails with error: 18/11/22 10:36:20 ERROR Version: Multiple ES-Hadoop versions detected in the classpath; please use only one jar:file:/C:/<...>/spark-2.4.0-bin-hadoop2.6/work/driver-20181122103610-0001/myApp-assembly-1.0.jar jar:file:/C:/<...>/myApp-assembly-1.0.jar 18/11/22 10:36:20 ERROR Main: Application [MyApp] failed: java.lang.Error: Multiple ES-Hadoop versions detected in the classpath; please use only one jar:file:/C:/<...>/spark-2.4.0-bin-hadoop2.6/work/driver-20181122103610-0001/myApp-assembly-1.0.jar jar:file:/C:/<...>/myApp-assembly-1.0.jar at org.elasticsearch.hadoop.util.Version.(Version.java:73) at org.elasticsearch.hadoop.rest.RestService.findPartitions(RestService.java:214) at org.elasticsearch.spark.rdd.AbstractEsRDD.esPartitions$lzycompute(AbstractEsRDD.scala:73) at org.elasticsearch.spark.rdd.AbstractEsRDD.esPartitions(AbstractEsRDD.scala:72) at org.elasticsearch.spark.rdd.AbstractEsRDD.getPartitions(AbstractEsRDD.scala:44) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:253) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:251) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:251) at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:49) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:253) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:251) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:251) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126) at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:363) at org.apache.spark.rdd.RDD.collect(RDD.scala:944) ... at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.spark.deploy.worker.DriverWrapper$.main(DriverWrapper.scala:65) at org.apache.spark.deploy.worker.DriverWrapper.main(DriverWrapper.scala) With fix application runs successfully. Closes #23102 from markpavey/JIRA_SPARK-26137_DependencyUtilsFileSeparatorFix. 
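For illustration, a standalone snippet (not part of the patch, using a made-up Windows path) showing why the hard coded "/" fails: on Windows the jar name is never extracted from the path, so the driver-local copy of the application jar is not filtered out of the classpath:

```scala
object SeparatorDemo {
  def main(args: Array[String]): Unit = {
    // '\\' is what java.io.File.separatorChar evaluates to on Windows; it is hard
    // coded here so the demo behaves the same on any OS.
    val windowsSeparator = '\\'
    val userJar = """C:\work\driver-20181122103610-0001\myApp-assembly-1.0.jar"""

    println(userJar.split("/").last)              // the whole path: "/" never occurs
    println(userJar.split(windowsSeparator).last) // "myApp-assembly-1.0.jar"
  }
}
```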
Authored-by: Mark Pavey Signed-off-by: Sean Owen --- .../apache/spark/deploy/DependencyUtils.scala | 3 ++- .../spark/deploy/SparkSubmitSuite.scala | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index 178bdcfccb603..5a17a6b6e169c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -61,11 +61,12 @@ private[deploy] object DependencyUtils extends Logging { hadoopConf: Configuration, secMgr: SecurityManager): String = { val targetDir = Utils.createTempDir() + val userJarName = userJar.split(File.separatorChar).last Option(jars) .map { resolveGlobPaths(_, hadoopConf) .split(",") - .filterNot(_.contains(userJar.split("/").last)) + .filterNot(_.contains(userJarName)) .mkString(",") } .filterNot(_ == "") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 652c36ffa6e71..c093789244bfe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -962,6 +962,25 @@ class SparkSubmitSuite } } + test("remove copies of application jar from classpath") { + val fs = File.separator + val sparkConf = new SparkConf(false) + val hadoopConf = new Configuration() + val secMgr = new SecurityManager(sparkConf) + + val appJarName = "myApp.jar" + val jar1Name = "myJar1.jar" + val jar2Name = "myJar2.jar" + val userJar = s"file:/path${fs}to${fs}app${fs}jar$fs$appJarName" + val jars = s"file:/$jar1Name,file:/$appJarName,file:/$jar2Name" + + val resolvedJars = DependencyUtils + .resolveAndDownloadJars(jars, userJar, sparkConf, hadoopConf, secMgr) + + assert(!resolvedJars.contains(appJarName)) + assert(resolvedJars.contains(jar1Name) && resolvedJars.contains(jar2Name)) + } + test("Avoid re-upload remote resources in yarn client mode") { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) From 87bd9c75df6b67bef903751269a4fd381f9140d9 Mon Sep 17 00:00:00 2001 From: Brandon Krieger Date: Wed, 28 Nov 2018 07:22:48 -0800 Subject: [PATCH 141/145] [SPARK-25998][CORE] Change TorrentBroadcast to hold weak reference of broadcast object ## What changes were proposed in this pull request? This PR changes the broadcast object in TorrentBroadcast from a strong reference to a weak reference. This allows it to be garbage collected even if the Dataset is held in memory. This is ok, because the broadcast object can always be re-read. ## How was this patch tested? Tested in Spark shell by taking a heap dump, full repro steps listed in https://issues.apache.org/jira/browse/SPARK-25998. Closes #22995 from bkrieger/bk/torrent-broadcast-weak. 
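The caching pattern introduced by this patch — keep the deserialized broadcast value behind a `java.lang.ref.SoftReference` and re-read it when the GC has reclaimed it — can be sketched in isolation as follows. This is a minimal, hypothetical cache for illustration only, not the actual TorrentBroadcast code (which additionally consults the block manager and the broadcast manager's cache):

```scala
import java.lang.ref.SoftReference

// Hold the expensive value behind a SoftReference so the GC may reclaim it
// under memory pressure, and transparently recompute it on the next access
// if it has been collected.
class SoftCachedValue[T](load: () => T) {
  private var ref: SoftReference[T] = _

  def get: T = synchronized {
    val cached = if (ref == null) null.asInstanceOf[T] else ref.get()
    if (cached != null) {
      cached
    } else {
      val loaded = load()              // e.g. re-read blocks from the block manager
      ref = new SoftReference[T](loaded)
      loaded
    }
  }
}

object SoftCachedValueExample {
  def main(args: Array[String]): Unit = {
    val cache = new SoftCachedValue(() => Array.fill(1 << 20)(0.0))
    println(cache.get.length) // loaded on first access
    println(cache.get.length) // served from the soft reference while it is still live
  }
}
```

A soft reference is only cleared when memory is tight, so repeated accesses normally hit the cached value; the trade-off is an occasional re-read of the broadcast blocks after a collection, which the commit message notes is always safe.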
Authored-by: Brandon Krieger Signed-off-by: Sean Owen --- .../spark/broadcast/TorrentBroadcast.scala | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index cbd49e070f2eb..26ead57316e18 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -18,6 +18,7 @@ package org.apache.spark.broadcast import java.io._ +import java.lang.ref.SoftReference import java.nio.ByteBuffer import java.util.zip.Adler32 @@ -61,9 +62,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]], * which builds this value by reading blocks from the driver and/or other executors. * - * On the driver, if the value is required, it is read lazily from the block manager. + * On the driver, if the value is required, it is read lazily from the block manager. We hold + * a soft reference so that it can be garbage collected if required, as we can always reconstruct + * in the future. */ - @transient private lazy val _value: T = readBroadcastBlock() + @transient private var _value: SoftReference[T] = _ /** The compression codec to use, or None if compression is disabled */ @transient private var compressionCodec: Option[CompressionCodec] = _ @@ -92,8 +95,15 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** The checksum for all the blocks. */ private var checksums: Array[Int] = _ - override protected def getValue() = { - _value + override protected def getValue() = synchronized { + val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get + if (memoized != null) { + memoized + } else { + val newlyRead = readBroadcastBlock() + _value = new SoftReference[T](newlyRead) + newlyRead + } } private def calcChecksum(block: ByteBuffer): Int = { @@ -205,8 +215,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } private def readBroadcastBlock(): T = Utils.tryOrIOException { - TorrentBroadcast.synchronized { - val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + broadcastCache.synchronized { Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { setConf(SparkEnv.get.conf) From f1609487d39dc4988abc72fd09d8569f91853dc1 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Mon, 28 Jan 2019 13:39:31 +0000 Subject: [PATCH 142/145] fix conflicts --- .circleci/config.yml | 21 +- assembly/pom.xml | 2 +- dev/deps/spark-deps-hadoop-2.7 | 200 - dev/deps/spark-deps-hadoop-3.1 | 218 - dev/deps/spark-deps-hadoop-palantir | 44 +- dev/docker-images/Makefile | 6 +- dev/docker-images/base/Dockerfile | 7 +- dev/docker-images/python/Dockerfile | 8 +- dev/make-distribution.sh | 2 +- dists/hadoop-palantir-bom/pom.xml | 6 +- dists/hadoop-palantir/pom.xml | 6 +- pom.xml | 4 +- project/MimaExcludes.scala | 7 +- python/pyspark/ml/tests.py | 2761 ------- python/pyspark/mllib/tests.py | 1788 ----- python/pyspark/sql/tests.py | 7109 ----------------- python/pyspark/sql/tests/test_arrow.py | 4 +- .../sql/tests/test_pandas_udf_scalar.py | 6 +- .../pyspark/streaming/tests/test_dstream.py | 186 - python/pyspark/tests.py | 2522 ------ .../submit/KubernetesDriverBuilderSuite.scala | 2 + 
.../src/main/dockerfiles/spark/Dockerfile | 4 - .../src/test/resources/ExpectedDockerfile | 12 +- .../describe-part-after-analyze.sql.out | 24 - .../org/apache/spark/sql/DataFrameSuite.scala | 3 - .../columnar/InMemoryColumnarQuerySuite.scala | 12 +- .../datasources/HadoopFsRelationSuite.scala | 6 +- 27 files changed, 70 insertions(+), 14900 deletions(-) delete mode 100644 dev/deps/spark-deps-hadoop-2.7 delete mode 100644 dev/deps/spark-deps-hadoop-3.1 delete mode 100755 python/pyspark/ml/tests.py delete mode 100644 python/pyspark/mllib/tests.py delete mode 100644 python/pyspark/sql/tests.py delete mode 100644 python/pyspark/tests.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 43f2d58acdf31..86636d6d20314 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2 defaults: &defaults docker: - - image: palantirtechnologies/circle-spark-base:0.1.0 + - image: palantirtechnologies/circle-spark-base:0.1.3 resource_class: xlarge environment: &defaults-environment TERM: dumb @@ -128,7 +128,7 @@ jobs: <<: *defaults # Some part of the maven setup fails if there's no R, so we need to use the R image here docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: # Saves us from recompiling every time... - restore_cache: @@ -147,12 +147,7 @@ jobs: keys: - build-binaries-{{ checksum "build/mvn" }}-{{ checksum "build/sbt" }} - build-binaries- - - run: | - ./build/mvn -T1C -DskipTests -Phadoop-cloud -Phadoop-palantir -Pkinesis-asl -Pkubernetes -Pyarn -Psparkr install \ - | tee -a "/tmp/mvn-install.log" - - store_artifacts: - path: /tmp/mvn-install.log - destination: mvn-install.log + - run: ./build/mvn -DskipTests -Phadoop-cloud -Phadoop-palantir -Pkinesis-asl -Pkubernetes -Pyarn -Psparkr install # Get sbt to run trivially, ensures its launcher is downloaded under build/ - run: ./build/sbt -h || true - save_cache: @@ -300,7 +295,7 @@ jobs: # depends on build-sbt, but we only need the assembly jars <<: *defaults docker: - - image: palantirtechnologies/circle-spark-python:0.1.0 + - image: palantirtechnologies/circle-spark-python:0.1.3 parallelism: 2 steps: - *checkout-code @@ -325,7 +320,7 @@ jobs: # depends on build-sbt, but we only need the assembly jars <<: *defaults docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: - *checkout-code - attach_workspace: @@ -438,7 +433,7 @@ jobs: <<: *defaults # Some part of the maven setup fails if there's no R, so we need to use the R image here docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: - *checkout-code - restore_cache: @@ -458,7 +453,7 @@ jobs: deploy-gradle: <<: *defaults docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: - *checkout-code - *restore-gradle-wrapper-cache @@ -470,7 +465,7 @@ jobs: <<: *defaults # Some part of the maven setup fails if there's no R, so we need to use the R image here docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: # This cache contains the whole project after version was set and mvn package was called # Restoring first (and instead of checkout) as mvn versions:set mutates real source code... diff --git a/assembly/pom.xml b/assembly/pom.xml index 68cda458133bd..6f8153e847a47 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -76,7 +76,7 @@
    org.apache.spark - spark-avro_2.11 + spark-avro_${scala.binary.version} ${project.version} diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 deleted file mode 100644 index ec7c304c9e36b..0000000000000 --- a/dev/deps/spark-deps-hadoop-2.7 +++ /dev/null @@ -1,200 +0,0 @@ -JavaEWAH-0.3.2.jar -RoaringBitmap-0.5.11.jar -ST4-4.0.4.jar -activation-1.1.1.jar -aircompressor-0.10.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar -antlr4-runtime-4.7.1.jar -aopalliance-1.0.jar -aopalliance-repackaged-2.4.0-b34.jar -apache-log4j-extras-1.2.17.jar -apacheds-i18n-2.0.0-M15.jar -apacheds-kerberos-codec-2.0.0-M15.jar -api-asn1-api-1.0.0-M20.jar -api-util-1.0.0-M20.jar -arpack_combined_all-0.1.jar -arrow-format-0.10.0.jar -arrow-memory-0.10.0.jar -arrow-vector-0.10.0.jar -automaton-1.11-8.jar -avro-1.8.2.jar -avro-ipc-1.8.2.jar -avro-mapred-1.8.2-hadoop2.jar -bonecp-0.8.0.RELEASE.jar -breeze-macros_2.12-0.13.2.jar -breeze_2.12-0.13.2.jar -calcite-avatica-1.2.0-incubating.jar -calcite-core-1.2.0-incubating.jar -calcite-linq4j-1.2.0-incubating.jar -chill-java-0.9.3.jar -chill_2.12-0.9.3.jar -commons-beanutils-1.7.0.jar -commons-beanutils-core-1.8.0.jar -commons-cli-1.2.jar -commons-codec-1.10.jar -commons-collections-3.2.2.jar -commons-compiler-3.0.10.jar -commons-compress-1.8.1.jar -commons-configuration-1.6.jar -commons-crypto-1.0.0.jar -commons-dbcp-1.4.jar -commons-digester-1.8.jar -commons-httpclient-3.1.jar -commons-io-2.4.jar -commons-lang-2.6.jar -commons-lang3-3.8.1.jar -commons-logging-1.1.3.jar -commons-math3-3.4.1.jar -commons-net-3.1.jar -commons-pool-1.5.4.jar -compress-lzf-1.0.3.jar -core-1.1.2.jar -curator-client-2.7.1.jar -curator-framework-2.7.1.jar -curator-recipes-2.7.1.jar -datanucleus-api-jdo-3.2.6.jar -datanucleus-core-3.2.10.jar -datanucleus-rdbms-3.2.9.jar -derby-10.12.1.1.jar -eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar -generex-1.0.1.jar -gson-2.2.4.jar -guava-14.0.1.jar -guice-3.0.jar -guice-servlet-3.0.jar -hadoop-annotations-2.7.4.jar -hadoop-auth-2.7.4.jar -hadoop-client-2.7.4.jar -hadoop-common-2.7.4.jar -hadoop-hdfs-2.7.4.jar -hadoop-mapreduce-client-app-2.7.4.jar -hadoop-mapreduce-client-common-2.7.4.jar -hadoop-mapreduce-client-core-2.7.4.jar -hadoop-mapreduce-client-jobclient-2.7.4.jar -hadoop-mapreduce-client-shuffle-2.7.4.jar -hadoop-yarn-api-2.7.4.jar -hadoop-yarn-client-2.7.4.jar -hadoop-yarn-common-2.7.4.jar -hadoop-yarn-server-common-2.7.4.jar -hadoop-yarn-server-web-proxy-2.7.4.jar -hk2-api-2.4.0-b34.jar -hk2-locator-2.4.0-b34.jar -hk2-utils-2.4.0-b34.jar -hppc-0.7.2.jar -htrace-core-3.1.0-incubating.jar -httpclient-4.5.6.jar -httpcore-4.4.10.jar -ivy-2.4.0.jar -jackson-annotations-2.9.6.jar -jackson-core-2.9.6.jar -jackson-core-asl-1.9.13.jar -jackson-databind-2.9.6.jar -jackson-dataformat-yaml-2.9.6.jar -jackson-jaxrs-1.9.13.jar -jackson-mapper-asl-1.9.13.jar -jackson-module-jaxb-annotations-2.9.6.jar -jackson-module-paranamer-2.9.6.jar -jackson-module-scala_2.12-2.9.6.jar -jackson-xc-1.9.13.jar -janino-3.0.10.jar -javassist-3.18.1-GA.jar -javax.annotation-api-1.2.jar -javax.inject-1.jar -javax.inject-2.4.0-b34.jar -javax.servlet-api-3.1.0.jar -javax.ws.rs-api-2.0.1.jar -javolution-5.5.1.jar -jaxb-api-2.2.2.jar -jcl-over-slf4j-1.7.16.jar -jdo-api-3.0.1.jar -jersey-client-2.22.2.jar -jersey-common-2.22.2.jar -jersey-container-servlet-2.22.2.jar -jersey-container-servlet-core-2.22.2.jar -jersey-guava-2.22.2.jar -jersey-media-jaxb-2.22.2.jar -jersey-server-2.22.2.jar -jetty-6.1.26.jar -jetty-sslengine-6.1.26.jar 
-jetty-util-6.1.26.jar -jline-2.14.6.jar -joda-time-2.9.3.jar -jodd-core-3.5.2.jar -jpam-1.1.jar -json4s-ast_2.12-3.5.3.jar -json4s-core_2.12-3.5.3.jar -json4s-jackson_2.12-3.5.3.jar -json4s-scalap_2.12-3.5.3.jar -jsp-api-2.1.jar -jsr305-3.0.0.jar -jta-1.1.jar -jtransforms-2.4.0.jar -jul-to-slf4j-1.7.16.jar -kryo-shaded-4.0.2.jar -kubernetes-client-4.1.0.jar -kubernetes-model-4.1.0.jar -leveldbjni-all-1.8.jar -libfb303-0.9.3.jar -libthrift-0.9.3.jar -log4j-1.2.17.jar -logging-interceptor-3.9.1.jar -lz4-java-1.5.0.jar -machinist_2.12-0.6.1.jar -macro-compat_2.12-1.1.1.jar -mesos-1.4.0-shaded-protobuf.jar -metrics-core-3.1.5.jar -metrics-graphite-3.1.5.jar -metrics-json-3.1.5.jar -metrics-jvm-3.1.5.jar -minlog-1.3.0.jar -netty-3.9.9.Final.jar -netty-all-4.1.30.Final.jar -objenesis-2.5.1.jar -okhttp-3.8.1.jar -okio-1.13.0.jar -opencsv-2.3.jar -orc-core-1.5.3-nohive.jar -orc-mapreduce-1.5.3-nohive.jar -orc-shims-1.5.3.jar -oro-2.0.8.jar -osgi-resource-locator-1.0.1.jar -paranamer-2.8.jar -parquet-column-1.10.0.jar -parquet-common-1.10.0.jar -parquet-encoding-1.10.0.jar -parquet-format-2.4.0.jar -parquet-hadoop-1.10.0.jar -parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.10.0.jar -protobuf-java-2.5.0.jar -py4j-0.10.8.1.jar -pyrolite-4.13.jar -scala-compiler-2.12.7.jar -scala-library-2.12.7.jar -scala-parser-combinators_2.12-1.1.0.jar -scala-reflect-2.12.7.jar -scala-xml_2.12-1.0.5.jar -shapeless_2.12-2.3.2.jar -slf4j-api-1.7.16.jar -slf4j-log4j12-1.7.16.jar -snakeyaml-1.18.jar -snappy-0.2.jar -snappy-java-1.1.7.1.jar -spire-macros_2.12-0.13.0.jar -spire_2.12-0.13.0.jar -stax-api-1.0-2.jar -stax-api-1.0.1.jar -stream-2.7.0.jar -stringtemplate-3.2.1.jar -super-csv-2.2.0.jar -univocity-parsers-2.7.3.jar -validation-api-1.1.0.Final.jar -xbean-asm7-shaded-4.12.jar -xercesImpl-2.9.1.jar -xmlenc-0.52.jar -xz-1.5.jar -zjsonpatch-0.3.0.jar -zookeeper-3.4.6.jar -zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 deleted file mode 100644 index 811febf22940d..0000000000000 --- a/dev/deps/spark-deps-hadoop-3.1 +++ /dev/null @@ -1,218 +0,0 @@ -HikariCP-java7-2.4.12.jar -JavaEWAH-0.3.2.jar -RoaringBitmap-0.5.11.jar -ST4-4.0.4.jar -accessors-smart-1.2.jar -activation-1.1.1.jar -aircompressor-0.10.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar -antlr4-runtime-4.7.1.jar -aopalliance-1.0.jar -aopalliance-repackaged-2.4.0-b34.jar -apache-log4j-extras-1.2.17.jar -arpack_combined_all-0.1.jar -arrow-format-0.10.0.jar -arrow-memory-0.10.0.jar -arrow-vector-0.10.0.jar -automaton-1.11-8.jar -avro-1.8.2.jar -avro-ipc-1.8.2.jar -avro-mapred-1.8.2-hadoop2.jar -bonecp-0.8.0.RELEASE.jar -breeze-macros_2.12-0.13.2.jar -breeze_2.12-0.13.2.jar -calcite-avatica-1.2.0-incubating.jar -calcite-core-1.2.0-incubating.jar -calcite-linq4j-1.2.0-incubating.jar -chill-java-0.9.3.jar -chill_2.12-0.9.3.jar -commons-beanutils-1.9.3.jar -commons-cli-1.2.jar -commons-codec-1.10.jar -commons-collections-3.2.2.jar -commons-compiler-3.0.10.jar -commons-compress-1.8.1.jar -commons-configuration2-2.1.1.jar -commons-crypto-1.0.0.jar -commons-daemon-1.0.13.jar -commons-dbcp-1.4.jar -commons-httpclient-3.1.jar -commons-io-2.4.jar -commons-lang-2.6.jar -commons-lang3-3.8.1.jar -commons-logging-1.1.3.jar -commons-math3-3.4.1.jar -commons-net-3.1.jar -commons-pool-1.5.4.jar -compress-lzf-1.0.3.jar -core-1.1.2.jar -curator-client-2.12.0.jar -curator-framework-2.12.0.jar -curator-recipes-2.12.0.jar -datanucleus-api-jdo-3.2.6.jar -datanucleus-core-3.2.10.jar -datanucleus-rdbms-3.2.9.jar -derby-10.12.1.1.jar 
-dnsjava-2.1.7.jar -ehcache-3.3.1.jar -eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar -generex-1.0.1.jar -geronimo-jcache_1.0_spec-1.0-alpha-1.jar -gson-2.2.4.jar -guava-14.0.1.jar -guice-4.0.jar -guice-servlet-4.0.jar -hadoop-annotations-3.1.0.jar -hadoop-auth-3.1.0.jar -hadoop-client-3.1.0.jar -hadoop-common-3.1.0.jar -hadoop-hdfs-client-3.1.0.jar -hadoop-mapreduce-client-common-3.1.0.jar -hadoop-mapreduce-client-core-3.1.0.jar -hadoop-mapreduce-client-jobclient-3.1.0.jar -hadoop-yarn-api-3.1.0.jar -hadoop-yarn-client-3.1.0.jar -hadoop-yarn-common-3.1.0.jar -hadoop-yarn-registry-3.1.0.jar -hadoop-yarn-server-common-3.1.0.jar -hadoop-yarn-server-web-proxy-3.1.0.jar -hk2-api-2.4.0-b34.jar -hk2-locator-2.4.0-b34.jar -hk2-utils-2.4.0-b34.jar -hppc-0.7.2.jar -htrace-core4-4.1.0-incubating.jar -httpclient-4.5.6.jar -httpcore-4.4.10.jar -ivy-2.4.0.jar -jackson-annotations-2.9.6.jar -jackson-core-2.9.6.jar -jackson-core-asl-1.9.13.jar -jackson-databind-2.9.6.jar -jackson-dataformat-yaml-2.9.6.jar -jackson-jaxrs-base-2.7.8.jar -jackson-jaxrs-json-provider-2.7.8.jar -jackson-mapper-asl-1.9.13.jar -jackson-module-jaxb-annotations-2.9.6.jar -jackson-module-paranamer-2.9.6.jar -jackson-module-scala_2.12-2.9.6.jar -janino-3.0.10.jar -javassist-3.18.1-GA.jar -javax.annotation-api-1.2.jar -javax.inject-1.jar -javax.inject-2.4.0-b34.jar -javax.servlet-api-3.1.0.jar -javax.ws.rs-api-2.0.1.jar -javolution-5.5.1.jar -jaxb-api-2.2.11.jar -jcip-annotations-1.0-1.jar -jcl-over-slf4j-1.7.16.jar -jdo-api-3.0.1.jar -jersey-client-2.22.2.jar -jersey-common-2.22.2.jar -jersey-container-servlet-2.22.2.jar -jersey-container-servlet-core-2.22.2.jar -jersey-guava-2.22.2.jar -jersey-media-jaxb-2.22.2.jar -jersey-server-2.22.2.jar -jetty-webapp-9.4.12.v20180830.jar -jetty-xml-9.4.12.v20180830.jar -jline-2.14.6.jar -joda-time-2.9.3.jar -jodd-core-3.5.2.jar -jpam-1.1.jar -json-smart-2.3.jar -json4s-ast_2.12-3.5.3.jar -json4s-core_2.12-3.5.3.jar -json4s-jackson_2.12-3.5.3.jar -json4s-scalap_2.12-3.5.3.jar -jsp-api-2.1.jar -jsr305-3.0.0.jar -jta-1.1.jar -jtransforms-2.4.0.jar -jul-to-slf4j-1.7.16.jar -kerb-admin-1.0.1.jar -kerb-client-1.0.1.jar -kerb-common-1.0.1.jar -kerb-core-1.0.1.jar -kerb-crypto-1.0.1.jar -kerb-identity-1.0.1.jar -kerb-server-1.0.1.jar -kerb-simplekdc-1.0.1.jar -kerb-util-1.0.1.jar -kerby-asn1-1.0.1.jar -kerby-config-1.0.1.jar -kerby-pkix-1.0.1.jar -kerby-util-1.0.1.jar -kerby-xdr-1.0.1.jar -kryo-shaded-4.0.2.jar -kubernetes-client-4.1.0.jar -kubernetes-model-4.1.0.jar -leveldbjni-all-1.8.jar -libfb303-0.9.3.jar -libthrift-0.9.3.jar -log4j-1.2.17.jar -logging-interceptor-3.9.1.jar -lz4-java-1.5.0.jar -machinist_2.12-0.6.1.jar -macro-compat_2.12-1.1.1.jar -mesos-1.4.0-shaded-protobuf.jar -metrics-core-3.1.5.jar -metrics-graphite-3.1.5.jar -metrics-json-3.1.5.jar -metrics-jvm-3.1.5.jar -minlog-1.3.0.jar -mssql-jdbc-6.2.1.jre7.jar -netty-3.9.9.Final.jar -netty-all-4.1.30.Final.jar -nimbus-jose-jwt-4.41.1.jar -objenesis-2.5.1.jar -okhttp-2.7.5.jar -okhttp-3.8.1.jar -okio-1.13.0.jar -opencsv-2.3.jar -orc-core-1.5.3-nohive.jar -orc-mapreduce-1.5.3-nohive.jar -orc-shims-1.5.3.jar -oro-2.0.8.jar -osgi-resource-locator-1.0.1.jar -paranamer-2.8.jar -parquet-column-1.10.0.jar -parquet-common-1.10.0.jar -parquet-encoding-1.10.0.jar -parquet-format-2.4.0.jar -parquet-hadoop-1.10.0.jar -parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.10.0.jar -protobuf-java-2.5.0.jar -py4j-0.10.8.1.jar -pyrolite-4.13.jar -re2j-1.1.jar -scala-compiler-2.12.7.jar -scala-library-2.12.7.jar 
-scala-parser-combinators_2.12-1.1.0.jar -scala-reflect-2.12.7.jar -scala-xml_2.12-1.0.5.jar -shapeless_2.12-2.3.2.jar -slf4j-api-1.7.16.jar -slf4j-log4j12-1.7.16.jar -snakeyaml-1.18.jar -snappy-0.2.jar -snappy-java-1.1.7.1.jar -spire-macros_2.12-0.13.0.jar -spire_2.12-0.13.0.jar -stax-api-1.0.1.jar -stax2-api-3.1.4.jar -stream-2.7.0.jar -stringtemplate-3.2.1.jar -super-csv-2.2.0.jar -token-provider-1.0.1.jar -univocity-parsers-2.7.3.jar -validation-api-1.1.0.Final.jar -woodstox-core-5.0.3.jar -xbean-asm7-shaded-4.12.jar -xz-1.5.jar -zjsonpatch-0.3.0.jar -zookeeper-3.4.9.jar -zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-palantir b/dev/deps/spark-deps-hadoop-palantir index 7253e4920b110..0d881c6cdab73 100644 --- a/dev/deps/spark-deps-hadoop-palantir +++ b/dev/deps/spark-deps-hadoop-palantir @@ -2,7 +2,7 @@ HikariCP-java7-2.4.12.jar RoaringBitmap-0.7.16.jar activation-1.1.1.jar aircompressor-0.10.jar -antlr4-runtime-4.7.jar +antlr4-runtime-4.7.1.jar aopalliance-1.0.jar aopalliance-repackaged-2.5.0-b32.jar apacheds-i18n-2.0.0-M15.jar @@ -21,10 +21,10 @@ avro-mapred-1.8.2-hadoop2.jar aws-java-sdk-bundle-1.11.201.jar azure-keyvault-core-0.8.0.jar azure-storage-5.4.0.jar -breeze-macros_2.11-0.13.2.jar -breeze_2.11-0.13.2.jar +breeze-macros_2.12-0.13.2.jar +breeze_2.12-0.13.2.jar chill-java-0.9.3.jar -chill_2.11-0.9.3.jar +chill_2.12-0.9.3.jar classmate-1.1.0.jar commons-beanutils-1.9.3.jar commons-beanutils-core-1.8.0.jar @@ -32,14 +32,14 @@ commons-cli-1.2.jar commons-codec-1.11.jar commons-collections-3.2.2.jar commons-compiler-3.0.10.jar -commons-compress-1.8.1.jar +commons-compress-1.18.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-digester-1.8.jar commons-httpclient-3.1.jar commons-io-2.6.jar commons-lang-2.6.jar -commons-lang3-3.8.jar +commons-lang3-3.8.1.jar commons-logging-1.2.jar commons-math3-3.6.1.jar commons-net-3.6.jar @@ -105,7 +105,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-afterburner-2.9.7.jar jackson-module-jaxb-annotations-2.9.7.jar jackson-module-paranamer-2.9.7.jar -jackson-module-scala_2.11-2.9.7.jar +jackson-module-scala_2.12-2.9.7.jar jackson-xc-1.9.13.jar janino-3.0.10.jar javassist-3.20.0-GA.jar @@ -132,10 +132,10 @@ jetty-util-6.1.26.jar jetty-util-9.4.12.v20180830.jar joda-time-2.10.jar json-smart-1.3.1.jar -json4s-ast_2.11-3.5.3.jar -json4s-core_2.11-3.5.3.jar -json4s-jackson_2.11-3.5.3.jar -json4s-scalap_2.11-3.5.3.jar +json4s-ast_2.12-3.5.3.jar +json4s-core_2.12-3.5.3.jar +json4s-jackson_2.12-3.5.3.jar +json4s-scalap_2.12-3.5.3.jar jsp-api-2.1.jar jsr305-3.0.2.jar jtransforms-2.4.0.jar @@ -148,8 +148,8 @@ leveldbjni-all-1.8.jar log4j-1.2.17.jar logging-interceptor-3.11.0.jar lz4-java-1.5.0.jar -machinist_2.11-0.6.1.jar -macro-compat_2.11-1.1.1.jar +machinist_2.12-0.6.1.jar +macro-compat_2.12-1.1.1.jar metrics-core-3.2.6.jar metrics-graphite-3.2.6.jar metrics-influxdb-1.2.2.jar @@ -181,18 +181,18 @@ protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar safe-logging-1.5.1.jar -scala-compiler-2.11.12.jar -scala-library-2.11.12.jar -scala-parser-combinators_2.11-1.1.0.jar -scala-reflect-2.11.12.jar -scala-xml_2.11-1.0.5.jar -shapeless_2.11-2.3.2.jar +scala-compiler-2.12.7.jar +scala-library-2.12.7.jar +scala-parser-combinators_2.12-1.1.0.jar +scala-reflect-2.12.7.jar +scala-xml_2.12-1.0.5.jar +shapeless_2.12-2.3.2.jar slf4j-api-1.7.25.jar slf4j-log4j12-1.7.25.jar snakeyaml-1.23.jar snappy-java-1.1.7.2.jar -spire-macros_2.11-0.13.0.jar -spire_2.11-0.13.0.jar +spire-macros_2.12-0.13.0.jar +spire_2.12-0.13.0.jar 
stax-api-1.0-2.jar stax2-api-3.1.4.jar stream-2.9.6.jar @@ -203,5 +203,5 @@ xbean-asm7-shaded-4.12.jar xmlenc-0.52.jar xz-1.5.jar zjsonpatch-0.3.0.jar -zookeeper-3.4.6.jar +zookeeper-3.4.7.jar zstd-jni-1.3.5-3.jar diff --git a/dev/docker-images/Makefile b/dev/docker-images/Makefile index ed3e3a5ee7687..168d386770655 100644 --- a/dev/docker-images/Makefile +++ b/dev/docker-images/Makefile @@ -17,9 +17,9 @@ .PHONY: all publish base python r -BASE_IMAGE_NAME = palantirtechnologies/circle-spark-base:0.1.0 -PYTHON_IMAGE_NAME = palantirtechnologies/circle-spark-python:0.1.0 -R_IMAGE_NAME = palantirtechnologies/circle-spark-r:0.1.0 +BASE_IMAGE_NAME = palantirtechnologies/circle-spark-base:0.1.3 +PYTHON_IMAGE_NAME = palantirtechnologies/circle-spark-python:0.1.3 +R_IMAGE_NAME = palantirtechnologies/circle-spark-r:0.1.3 all: base python r diff --git a/dev/docker-images/base/Dockerfile b/dev/docker-images/base/Dockerfile index 0e84ec665fcd7..61293349a0db5 100644 --- a/dev/docker-images/base/Dockerfile +++ b/dev/docker-images/base/Dockerfile @@ -31,7 +31,7 @@ RUN mkdir -p /usr/share/man/man1 \ git \ locales sudo openssh-client ca-certificates tar gzip parallel \ net-tools netcat unzip zip bzip2 gnupg curl wget \ - openjdk-8-jdk rsync pandoc pandoc-citeproc flake8 \ + openjdk-8-jdk rsync pandoc pandoc-citeproc flake8 tzdata \ && rm -rf /var/lib/apt/lists/* # If you update java, make sure this aligns @@ -42,8 +42,9 @@ ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 RUN ln -sf /usr/share/zoneinfo/Etc/UTC /etc/localtime # Use unicode -RUN locale-gen C.UTF-8 || true -ENV LANG=C.UTF-8 +RUN locale-gen en_US.UTF-8 +ENV LANG=en_US.UTF-8 +ENV TZ=UTC # install jq RUN JQ_URL="https://circle-downloads.s3.amazonaws.com/circleci-images/cache/linux-amd64/jq-latest" \ diff --git a/dev/docker-images/python/Dockerfile b/dev/docker-images/python/Dockerfile index cb44b373617da..bb039b7c5c8b6 100644 --- a/dev/docker-images/python/Dockerfile +++ b/dev/docker-images/python/Dockerfile @@ -34,12 +34,8 @@ RUN mkdir -p $(pyenv root)/versions \ RUN pyenv global our-miniconda/envs/python2 our-miniconda/envs/python3 \ && pyenv rehash -RUN $CONDA_BIN install -y -n python2 -c anaconda -c conda-forge xmlrunner \ - && $CONDA_BIN install -y -n python3 -c anaconda -c conda-forge xmlrunner \ - && $CONDA_BIN clean --all - # Expose pyenv globally ENV PATH=$CIRCLE_HOME/.pyenv/shims:$PATH -RUN PYENV_VERSION=our-miniconda/envs/python2 $CIRCLE_HOME/.pyenv/shims/pip install unishark \ - && PYENV_VERSION=our-miniconda/envs/python3 $CIRCLE_HOME/.pyenv/shims/pip install unishark +RUN PYENV_VERSION=our-miniconda/envs/python2 $CIRCLE_HOME/.pyenv/shims/pip install unishark unittest-xml-reporting \ + && PYENV_VERSION=our-miniconda/envs/python3 $CIRCLE_HOME/.pyenv/shims/pip install unishark unittest-xml-reporting diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index e4405d89a1c17..1bde61b998ca5 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -173,7 +173,7 @@ fi # Normal quoting tricks don't work. # See: http://mywiki.wooledge.org/BashFAQ/050 if [[ -z "$DONT_BUILD" ]]; then - BUILD_COMMAND=("$MVN" -T 1C $MAYBE_CLEAN package -DskipTests $@) + BUILD_COMMAND=("$MVN" $MAYBE_CLEAN package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." 
diff --git a/dists/hadoop-palantir-bom/pom.xml b/dists/hadoop-palantir-bom/pom.xml index 4dcbb5149f750..523997699e9e6 100644 --- a/dists/hadoop-palantir-bom/pom.xml +++ b/dists/hadoop-palantir-bom/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-dist_2.11-hadoop-palantir-bom + spark-dist_2.12-hadoop-palantir-bom Spark Project Dist Palantir Hadoop (BOM) http://spark.apache.org/ pom @@ -107,7 +107,7 @@ org.apache.spark - spark-avro_2.11 + spark-avro_${scala.binary.version} ${project.version} diff --git a/dists/hadoop-palantir/pom.xml b/dists/hadoop-palantir/pom.xml index d3158b1bbef10..a3bb85e70f844 100644 --- a/dists/hadoop-palantir/pom.xml +++ b/dists/hadoop-palantir/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-dist_2.11-hadoop-palantir-bom + spark-dist_2.12-hadoop-palantir-bom 3.0.0-SNAPSHOT ../hadoop-palantir-bom/pom.xml - spark-dist_2.11-hadoop-palantir + spark-dist_2.12-hadoop-palantir Spark Project Dist Palantir Hadoop http://spark.apache.org/ pom @@ -57,7 +57,7 @@ org.apache.spark - spark-avro_2.11 + spark-avro_${scala.binary.version} org.apache.spark diff --git a/pom.xml b/pom.xml index 413d95ff431a2..0654831b70582 100644 --- a/pom.xml +++ b/pom.xml @@ -116,14 +116,14 @@ 1.8 ${java.version} ${java.version} - 3.5.4 + 3.6.0 spark 1.7.25 1.2.17 2.7.4 2.5.0 ${hadoop.version} - 3.4.6 + 3.4.7 2.7.1 org.spark-project.hive diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index da64b6e38a21c..eacda42813130 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -227,7 +227,12 @@ object MimaExcludes { // [SPARK-26141] Enable custom metrics implementation in shuffle write // Following are Java private classes ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"), + + // SafeLogging after MimaUpgrade + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.initializeLogIfNecessary$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.initializeLogIfNecessary$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.broadcast.Broadcast.initializeLogIfNecessary$default$2") ) // Exclude rules for 2.4.x diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py deleted file mode 100755 index 8c4f02dd724b4..0000000000000 --- a/python/pyspark/ml/tests.py +++ /dev/null @@ -1,2761 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -""" -Unit tests for MLlib Python DataFrame-based APIs. -""" -import sys - -import unishark - -if sys.version > '3': - xrange = range - basestring = str - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from shutil import rmtree -import tempfile -import array as pyarray -import numpy as np -from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros -import inspect -import py4j - -from pyspark import keyword_only, SparkContext -from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer -from pyspark.ml.classification import * -from pyspark.ml.clustering import * -from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ - MulticlassClassificationEvaluator, RegressionEvaluator -from pyspark.ml.feature import * -from pyspark.ml.fpm import FPGrowth, FPGrowthModel -from pyspark.ml.image import ImageSchema -from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ - SparseMatrix, SparseVector, Vector, VectorUDT, Vectors -from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed -from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ - LinearRegression -from pyspark.ml.stat import ChiSquareTest -from pyspark.ml.tuning import * -from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.serializers import PickleSerializer -from pyspark.sql import DataFrame, Row, SparkSession, HiveContext -from pyspark.sql.functions import rand -from pyspark.sql.types import DoubleType, IntegerType -from pyspark.storagelevel import * -from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class SparkSessionTestCase(PySparkTestCase): - @classmethod - def setUpClass(cls): - PySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - PySparkTestCase.tearDownClass() - cls.spark.stop() - - -class MockDataset(DataFrame): - - def __init__(self): - self.index = 0 - - -class HasFake(Params): - - def __init__(self): - super(HasFake, self).__init__() - self.fake = Param(self, "fake", "fake param") - - def getFake(self): - return self.getOrDefault(self.fake) - - -class MockTransformer(Transformer, HasFake): - - def __init__(self): - super(MockTransformer, self).__init__() - self.dataset_index = None - - def _transform(self, dataset): - self.dataset_index = dataset.index - dataset.index += 1 - return dataset - - -class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): - - shift = Param(Params._dummy(), "shift", "The amount by which to shift " + - "data in a DataFrame", - typeConverter=TypeConverters.toFloat) - - def __init__(self, shiftVal=1): - super(MockUnaryTransformer, self).__init__() - self._setDefault(shift=1) - self._set(shift=shiftVal) - - def getShift(self): - return self.getOrDefault(self.shift) - - def setShift(self, shift): - self._set(shift=shift) - - def 
createTransformFunc(self): - shiftVal = self.getShift() - return lambda x: x + shiftVal - - def outputDataType(self): - return DoubleType() - - def validateInputType(self, inputType): - if inputType != DoubleType(): - raise TypeError("Bad input type: {}. ".format(inputType) + - "Requires Double.") - - -class MockEstimator(Estimator, HasFake): - - def __init__(self): - super(MockEstimator, self).__init__() - self.dataset_index = None - - def _fit(self, dataset): - self.dataset_index = dataset.index - model = MockModel() - self._copyValues(model) - return model - - -class MockModel(MockTransformer, Model, HasFake): - pass - - -class JavaWrapperMemoryTests(SparkSessionTestCase): - - def test_java_object_gets_detached(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - - model = lr.fit(df) - summary = model.summary - - self.assertIsInstance(model, JavaWrapper) - self.assertIsInstance(summary, JavaWrapper) - self.assertIsInstance(model, JavaParams) - self.assertNotIsInstance(summary, JavaParams) - - error_no_object = 'Target Object ID does not exist for this gateway' - - self.assertIn("LinearRegression_", model._java_obj.toString()) - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - model.__del__() - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - try: - summary.__del__() - except: - pass - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - summary._java_obj.toString() - - -class ParamTypeConversionTests(PySparkTestCase): - """ - Test that param type conversion happens. 
- """ - - def test_int(self): - lr = LogisticRegression(maxIter=5.0) - self.assertEqual(lr.getMaxIter(), 5) - self.assertTrue(type(lr.getMaxIter()) == int) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) - - def test_float(self): - lr = LogisticRegression(tol=1) - self.assertEqual(lr.getTol(), 1.0) - self.assertTrue(type(lr.getTol()) == float) - self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) - - def test_vector(self): - ewp = ElementwiseProduct(scalingVec=[1, 3]) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) - ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) - self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) - - def test_list(self): - l = [0, 1] - for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), - range(len(l)), l), pyarray.array('l', l), xrange(2), tuple(l)]: - converted = TypeConverters.toList(lst_like) - self.assertEqual(type(converted), list) - self.assertListEqual(converted, l) - - def test_list_int(self): - for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), - SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), - pyarray.array('d', [1.0, 2.0])]: - vs = VectorSlicer(indices=indices) - self.assertListEqual(vs.getIndices(), [1, 2]) - self.assertTrue(all([type(v) == int for v in vs.getIndices()])) - self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) - - def test_list_float(self): - b = Bucketizer(splits=[1, 4]) - self.assertEqual(b.getSplits(), [1.0, 4.0]) - self.assertTrue(all([type(v) == float for v in b.getSplits()])) - self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) - - def test_list_string(self): - for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: - idx_to_string = IndexToString(labels=labels) - self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) - self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) - - def test_string(self): - lr = LogisticRegression() - for col in ['features', u'features', np.str_('features')]: - lr.setFeaturesCol(col) - self.assertEqual(lr.getFeaturesCol(), 'features') - self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) - - def test_bool(self): - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) - - -class PipelineTests(PySparkTestCase): - - def test_pipeline(self): - dataset = MockDataset() - estimator0 = MockEstimator() - transformer1 = MockTransformer() - estimator2 = MockEstimator() - transformer3 = MockTransformer() - pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) - pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - model0, transformer1, model2, transformer3 = pipeline_model.stages - self.assertEqual(0, model0.dataset_index) - self.assertEqual(0, model0.getFake()) - self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.getFake()) - self.assertEqual(2, dataset.index) - self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") - self.assertIsNone(transformer3.dataset_index, - "The last transformer shouldn't be called in fit.") - dataset = pipeline_model.transform(dataset) - self.assertEqual(2, model0.dataset_index) - 
self.assertEqual(3, transformer1.dataset_index) - self.assertEqual(4, model2.dataset_index) - self.assertEqual(5, transformer3.dataset_index) - self.assertEqual(6, dataset.index) - - def test_identity_pipeline(self): - dataset = MockDataset() - - def doTransform(pipeline): - pipeline_model = pipeline.fit(dataset) - return pipeline_model.transform(dataset) - # check that empty pipeline did not perform any transformation - self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) - # check that failure to set stages param will raise KeyError for missing param - self.assertRaises(KeyError, lambda: doTransform(Pipeline())) - - -class TestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(TestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(OtherTestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class HasThrowableProperty(Params): - - def __init__(self): - super(HasThrowableProperty, self).__init__() - self.p = Param(self, "none", "empty param") - - @property - def test_property(self): - raise RuntimeError("Test property to raise error when invoked") - - -class ParamTests(SparkSessionTestCase): - - def test_copy_new_parent(self): - testParams = TestParams() - # Copying an instantiated param should fail - with self.assertRaises(ValueError): - testParams.maxIter._copy_new_parent(testParams) - # Copying a dummy param should succeed - TestParams.maxIter._copy_new_parent(testParams) - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_param(self): - testParams = TestParams() - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_hasparam(self): - testParams = TestParams() - self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) - self.assertFalse(testParams.hasParam("notAParameter")) - self.assertTrue(testParams.hasParam(u"maxIter")) - - def test_resolveparam(self): - testParams = TestParams() - self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) - self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) - - self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) - if sys.version_info[0] >= 3: - # In Python 3, it is allowed to get/set attributes with non-ascii characters. 
- e_cls = AttributeError - else: - e_cls = UnicodeEncodeError - self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) - - def test_params(self): - testParams = TestParams() - maxIter = testParams.maxIter - inputCol = testParams.inputCol - seed = testParams.seed - - params = testParams.params - self.assertEqual(params, [inputCol, maxIter, seed]) - - self.assertTrue(testParams.hasParam(maxIter.name)) - self.assertTrue(testParams.hasDefault(maxIter)) - self.assertFalse(testParams.isSet(maxIter)) - self.assertTrue(testParams.isDefined(maxIter)) - self.assertEqual(testParams.getMaxIter(), 10) - testParams.setMaxIter(100) - self.assertTrue(testParams.isSet(maxIter)) - self.assertEqual(testParams.getMaxIter(), 100) - - self.assertTrue(testParams.hasParam(inputCol.name)) - self.assertFalse(testParams.hasDefault(inputCol)) - self.assertFalse(testParams.isSet(inputCol)) - self.assertFalse(testParams.isDefined(inputCol)) - with self.assertRaises(KeyError): - testParams.getInputCol() - - otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " + - "set raises an error for a non-member parameter.", - typeConverter=TypeConverters.toString) - with self.assertRaises(ValueError): - testParams.set(otherParam, "value") - - # Since the default is normally random, set it to a known number for debug str - testParams._setDefault(seed=41) - testParams.setSeed(43) - - self.assertEqual( - testParams.explainParams(), - "\n".join(["inputCol: input column name. (undefined)", - "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", - "seed: random seed. (default: 41, current: 43)"])) - - def test_kmeans_param(self): - algo = KMeans() - self.assertEqual(algo.getInitMode(), "k-means||") - algo.setK(10) - self.assertEqual(algo.getK(), 10) - algo.setInitSteps(10) - self.assertEqual(algo.getInitSteps(), 10) - self.assertEqual(algo.getDistanceMeasure(), "euclidean") - algo.setDistanceMeasure("cosine") - self.assertEqual(algo.getDistanceMeasure(), "cosine") - - def test_hasseed(self): - noSeedSpecd = TestParams() - withSeedSpecd = TestParams(seed=42) - other = OtherTestParams() - # Check that we no longer use 42 as the magic number - self.assertNotEqual(noSeedSpecd.getSeed(), 42) - origSeed = noSeedSpecd.getSeed() - # Check that we only compute the seed once - self.assertEqual(noSeedSpecd.getSeed(), origSeed) - # Check that a specified seed is honored - self.assertEqual(withSeedSpecd.getSeed(), 42) - # Check that a different class has a different seed - self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) - - def test_param_property_error(self): - param_store = HasThrowableProperty() - self.assertRaises(RuntimeError, lambda: param_store.test_property) - params = param_store.params # should not invoke the property 'test_property' - self.assertEqual(len(params), 1) - - def test_word2vec_param(self): - model = Word2Vec().setWindowSize(6) - # Check windowSize is set properly - self.assertEqual(model.getWindowSize(), 6) - - def test_copy_param_extras(self): - tp = TestParams(seed=42) - extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} - tp_copy = tp.copy(extra=extra) - self.assertEqual(tp.uid, tp_copy.uid) - self.assertEqual(tp.params, tp_copy.params) - for k, v in extra.items(): - self.assertTrue(tp_copy.isDefined(k)) - self.assertEqual(tp_copy.getOrDefault(k), v) - copied_no_extra = {} - for k, v in tp_copy._paramMap.items(): - if k not in extra: - copied_no_extra[k] = v - self.assertEqual(tp._paramMap, copied_no_extra) - 
self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) - - def test_logistic_regression_check_thresholds(self): - self.assertIsInstance( - LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), - LogisticRegression - ) - - self.assertRaisesRegexp( - ValueError, - "Logistic Regression getThreshold found inconsistent.*$", - LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] - ) - - def test_preserve_set_state(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - self.assertFalse(binarizer.isSet("threshold")) - binarizer.transform(dataset) - binarizer._transfer_params_from_java() - self.assertFalse(binarizer.isSet("threshold"), - "Params not explicitly set should remain unset after transform") - - def test_default_params_transferred(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - # intentionally change the pyspark default, but don't set it - binarizer._defaultParamMap[binarizer.outputCol] = "my_default" - result = binarizer.transform(dataset).select("my_default").collect() - self.assertFalse(binarizer.isSet(binarizer.outputCol)) - self.assertEqual(result[0][0], 1.0) - - @staticmethod - def check_params(test_self, py_stage, check_params_exist=True): - """ - Checks common requirements for Params.params: - - set of params exist in Java and Python and are ordered by names - - param parent has the same UID as the object's UID - - default param value from Java matches value in Python - - optionally check if all params from Java also exist in Python - """ - py_stage_str = "%s %s" % (type(py_stage), py_stage) - if not hasattr(py_stage, "_to_java"): - return - java_stage = py_stage._to_java() - if java_stage is None: - return - test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) - if check_params_exist: - param_names = [p.name for p in py_stage.params] - java_params = list(java_stage.params()) - java_param_names = [jp.name() for jp in java_params] - test_self.assertEqual( - param_names, sorted(java_param_names), - "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" - % (py_stage_str, java_param_names, param_names)) - for p in py_stage.params: - test_self.assertEqual(p.parent, py_stage.uid) - java_param = java_stage.getParam(p.name) - py_has_default = py_stage.hasDefault(p) - java_has_default = java_stage.hasDefault(java_param) - test_self.assertEqual(py_has_default, java_has_default, - "Default value mismatch of param %s for Params %s" - % (p.name, str(py_stage))) - if py_has_default: - if p.name == "seed": - continue # Random seeds between Spark and PySpark are different - java_default = _java2py(test_self.sc, - java_stage.clear(java_param).getOrDefault(java_param)) - py_stage._clear(p) - py_default = py_stage.getOrDefault(p) - # equality test for NaN is always False - if isinstance(java_default, float) and np.isnan(java_default): - java_default = "NaN" - py_default = "NaN" if np.isnan(py_default) else "not NaN" - test_self.assertEqual( - java_default, py_default, - "Java default %s != python default %s of param %s for Params %s" - % (str(java_default), str(py_default), p.name, str(py_stage))) - - -class EvaluatorTests(SparkSessionTestCase): - - def test_java_params(self): - """ - This tests a bug fixed by SPARK-18274 which causes multiple copies - of a Params instance in Python to be linked to the same Java instance. 
- """ - evaluator = RegressionEvaluator(metricName="r2") - df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) - evaluator.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"}) - evaluator.evaluate(df) - evaluatorCopy.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") - - def test_clustering_evaluator_with_cosine_distance(self): - featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), - [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), - ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) - dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) - evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") - self.assertEqual(evaluator.getDistanceMeasure(), "cosine") - self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) - - -class FeatureTests(SparkSessionTestCase): - - def test_binarizer(self): - b0 = Binarizer() - self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) - self.assertTrue(all([~b0.isSet(p) for p in b0.params])) - self.assertTrue(b0.hasDefault(b0.threshold)) - self.assertEqual(b0.getThreshold(), 0.0) - b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) - self.assertTrue(all([b0.isSet(p) for p in b0.params])) - self.assertEqual(b0.getThreshold(), 1.0) - self.assertEqual(b0.getInputCol(), "input") - self.assertEqual(b0.getOutputCol(), "output") - - b0c = b0.copy({b0.threshold: 2.0}) - self.assertEqual(b0c.uid, b0.uid) - self.assertListEqual(b0c.params, b0.params) - self.assertEqual(b0c.getThreshold(), 2.0) - - b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") - self.assertNotEqual(b1.uid, b0.uid) - self.assertEqual(b1.getThreshold(), 2.0) - self.assertEqual(b1.getInputCol(), "input") - self.assertEqual(b1.getOutputCol(), "output") - - def test_idf(self): - dataset = self.spark.createDataFrame([ - (DenseVector([1.0, 2.0]),), - (DenseVector([0.0, 1.0]),), - (DenseVector([3.0, 0.2]),)], ["tf"]) - idf0 = IDF(inputCol="tf") - self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) - idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) - self.assertEqual(idf0m.uid, idf0.uid, - "Model should inherit the UID from its parent estimator.") - output = idf0m.transform(dataset) - self.assertIsNotNone(output.head().idf) - # Test that parameters transferred to Python Model - ParamTests.check_params(self, idf0m) - - def test_ngram(self): - dataset = self.spark.createDataFrame([ - Row(input=["a", "b", "c", "d", "e"])]) - ngram0 = NGram(n=4, inputCol="input", outputCol="output") - self.assertEqual(ngram0.getN(), 4) - self.assertEqual(ngram0.getInputCol(), "input") - self.assertEqual(ngram0.getOutputCol(), "output") - transformedDF = ngram0.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) - - def test_stopwordsremover(self): - dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) - stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") - # Default - self.assertEqual(stopWordRemover.getInputCol(), "input") - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["panda"]) - self.assertEqual(type(stopWordRemover.getStopWords()), list) - self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], 
basestring)) - # Custom - stopwords = ["panda"] - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getInputCol(), "input") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a"]) - # with language selection - stopwords = StopWordsRemover.loadDefaultStopWords("turkish") - dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - # with locale - stopwords = ["BELKİ"] - dataset = self.spark.createDataFrame([Row(input=["belki"])]) - stopWordRemover.setStopWords(stopwords).setLocale("tr") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - - def test_count_vectorizer_with_binary(self): - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) - cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") - model = cv.fit(dataset) - - transformedList = model.transform(dataset).select("features", "expected").collect() - - for r in transformedList: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_with_maxDF(self): - dataset = self.spark.createDataFrame([ - (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), - (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - cv = CountVectorizer(inputCol="words", outputCol="features") - model1 = cv.setMaxDF(3).fit(dataset) - self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) - - transformedList1 = model1.transform(dataset).select("features", "expected").collect() - - for r in transformedList1: - feature, expected = r - self.assertEqual(feature, expected) - - model2 = cv.setMaxDF(0.75).fit(dataset) - self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) - - transformedList2 = model2.transform(dataset).select("features", "expected").collect() - - for r in transformedList2: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_from_vocab(self): - model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", - outputCol="features", minTF=2) - self.assertEqual(model.vocabulary, ["a", "b", "c"]) - self.assertEqual(model.getMinTF(), 2) - - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), - (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - - transformed_list = model.transform(dataset).select("features", "expected").collect() - - for r in transformed_list: - feature, expected = r - self.assertEqual(feature, expected) - - # Test an empty vocabulary - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): - CountVectorizerModel.from_vocabulary([], inputCol="words") - - # Test model with default settings can 
transform - model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") - transformed_list = model_default.transform(dataset)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 3) - - def test_rformula_force_index_label(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - # Does not index label by default since it's numeric type. - rf = RFormula(formula="y ~ x + s") - model = rf.fit(df) - transformedDF = model.transform(df) - self.assertEqual(transformedDF.head().label, 1.0) - # Force to index label. - rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) - model2 = rf2.fit(df) - transformedDF2 = model2.transform(df) - self.assertEqual(transformedDF2.head().label, 0.0) - - def test_rformula_string_indexer_order_type(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") - self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') - transformedDF = rf.fit(df).transform(df) - observed = transformedDF.select("features").collect() - expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] - for i in range(0, len(expected)): - self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) - - def test_string_indexer_handle_invalid(self): - df = self.spark.createDataFrame([ - (0, "a"), - (1, "d"), - (2, None)], ["id", "label"]) - - si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", - stringOrderType="alphabetAsc") - model1 = si1.fit(df) - td1 = model1.transform(df) - actual1 = td1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] - self.assertEqual(actual1, expected1) - - si2 = si1.setHandleInvalid("skip") - model2 = si2.fit(df) - td2 = model2.transform(df) - actual2 = td2.select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] - self.assertEqual(actual2, expected2) - - def test_string_indexer_from_labels(self): - model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", - outputCol="indexed", handleInvalid="keep") - self.assertEqual(model.labels, ["a", "b", "c"]) - - df1 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, None), - (3, "b"), - (4, "b")], ["id", "label"]) - - result1 = model.transform(df1) - actual1 = result1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), - Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] - self.assertEqual(actual1, expected1) - - model_empty_labels = StringIndexerModel.from_labels( - [], inputCol="label", outputCol="indexed", handleInvalid="keep") - actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), - Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] - self.assertEqual(actual2, expected2) - - # Test model with default settings can transform - model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") - df2 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, "b"), - (3, "b"), - (4, "b")], ["id", "label"]) - transformed_list = model_default.transform(df2)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 5) - - def 
test_vector_size_hint(self): - df = self.spark.createDataFrame( - [(0, Vectors.dense([0.0, 10.0, 0.5])), - (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), - (2, Vectors.dense([2.0, 12.0]))], - ["id", "vector"]) - - sizeHint = VectorSizeHint( - inputCol="vector", - handleInvalid="skip") - sizeHint.setSize(3) - self.assertEqual(sizeHint.getSize(), 3) - - output = sizeHint.transform(df).head().vector - expected = DenseVector([0.0, 10.0, 0.5]) - self.assertEqual(output, expected) - - -class HasInducedError(Params): - - def __init__(self): - super(HasInducedError, self).__init__() - self.inducedError = Param(self, "inducedError", - "Uniformly-distributed error added to feature") - - def getInducedError(self): - return self.getOrDefault(self.inducedError) - - -class InducedErrorModel(Model, HasInducedError): - - def __init__(self): - super(InducedErrorModel, self).__init__() - - def _transform(self, dataset): - return dataset.withColumn("prediction", - dataset.feature + (rand(0) * self.getInducedError())) - - -class InducedErrorEstimator(Estimator, HasInducedError): - - def __init__(self, inducedError=1.0): - super(InducedErrorEstimator, self).__init__() - self._set(inducedError=inducedError) - - def _fit(self, dataset): - model = InducedErrorModel() - self._copyValues(model) - return model - - -class CrossValidatorTests(SparkSessionTestCase): - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvCopied = cv.copy() - self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) - - cvModel = cv.fit(dataset) - cvModelCopied = cvModel.copy() - for index in range(len(cvModel.avgMetrics)): - self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) - < 0.0001) - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has 
R-squared of 1") - - def test_param_grid_type_coercion(self): - lr = LogisticRegression(maxIter=10) - paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() - for param in paramGrid: - for v in param.values(): - assert(type(v) == float) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for CrossValidator will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - lrModel = cvModel.bestModel - - cvModelPath = temp_path + "/cvModel" - lrModel.save(cvModelPath) - loadedLrModel = LogisticRegressionModel.load(cvModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cv.setParallelism(1) - cvSerialModel = cv.fit(dataset) - cv.setParallelism(2) - cvParallelModel = cv.fit(dataset) - self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - numFolds = 3 - cv = CrossValidator(estimator=lr, 
estimatorParamMaps=grid, evaluator=evaluator, - numFolds=numFolds, collectSubModels=True) - - def checkSubModels(subModels): - self.assertEqual(len(subModels), numFolds) - for i in range(numFolds): - self.assertEqual(len(subModels[i]), len(grid)) - - cvModel = cv.fit(dataset) - checkSubModels(cvModel.subModels) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testCrossValidatorSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - cvModel.save(savingPathWithSubModels) - cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) - checkSubModels(cvModel3.subModels) - cvModel4 = cvModel3.copy() - checkSubModels(cvModel4.subModels) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) - self.assertEqual(cvModel2.subModels, None) - - for i in range(numFolds): - for j in range(len(grid)): - self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) - - def test_save_load_nested_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - - originalParamMap = cv.getEstimatorParamMaps() - loadedParamMap = loadedCV.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - -class TrainValidationSplitTests(SparkSessionTestCase): - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same 
size of grid parameter") - self.assertEqual(0.0, min(validationMetrics)) - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same size of grid parameter") - self.assertEqual(1.0, max(validationMetrics)) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - lrModel = tvsModel.bestModel - - tvsModelPath = temp_path + "/tvsModel" - lrModel.save(tvsModelPath) - loadedLrModel = LogisticRegressionModel.load(tvsModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - # This tests saving and loading the trained model only. 
- # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvs.setParallelism(1) - tvsSerialModel = tvs.fit(dataset) - tvs.setParallelism(2) - tvsParallelModel = tvs.fit(dataset) - self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, - collectSubModels=True) - tvsModel = tvs.fit(dataset) - self.assertEqual(len(tvsModel.subModels), len(grid)) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testTrainValidationSplitSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - tvsModel.save(savingPathWithSubModels) - tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) - self.assertEqual(len(tvsModel3.subModels), len(grid)) - tvsModel4 = tvsModel3.copy() - self.assertEqual(len(tvsModel4.subModels), len(grid)) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) - self.assertEqual(tvsModel2.subModels, None) - - for i in range(len(grid)): - self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) - - def test_save_load_nested_estimator(self): - # This tests saving and loading the trained model only. 
- # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - - originalParamMap = tvs.getEstimatorParamMaps() - loadedParamMap = loadedTvs.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsCopied = tvs.copy() - tvsModelCopied = tvsModel.copy() - - self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, - "Copied TrainValidationSplit has the same uid of Estimator") - - self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) - self.assertEqual(len(tvsModel.validationMetrics), - len(tvsModelCopied.validationMetrics), - "Copied validationMetrics has the same size of the original") - for index in range(len(tvsModel.validationMetrics)): - self.assertEqual(tvsModel.validationMetrics[index], - tvsModelCopied.validationMetrics[index]) - - -class PersistenceTest(SparkSessionTestCase): - - def test_linear_regression(self): - lr = LinearRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/lr" - lr.save(lr_path) - lr2 = LinearRegression.load(lr_path) - self.assertEqual(lr.uid, lr2.uid) - self.assertEqual(type(lr.uid), type(lr2.uid)) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LinearRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_linear_regression_pmml_basic(self): - # Most of the validation is done in the Scala side, here we just check - # that we output text rather than parquet (e.g. that the format flag - # was respected). 
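A minimal standalone sketch of the PMML export path this test exercises, assuming a local SparkSession and a writable /tmp output directory (both illustrative assumptions, not part of the patch):

    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.regression import LinearRegression

    spark = SparkSession.builder.master("local[2]").appName("pmml-sketch").getOrCreate()
    train = spark.createDataFrame(
        [(1.0, Vectors.dense(1.0)),
         (0.0, Vectors.dense(0.0))],
        ["label", "features"])
    model = LinearRegression(maxIter=1).fit(train)
    # format("pmml") routes the save through the general-purpose model writer and
    # emits PMML text files instead of the default Parquet-based model layout.
    model.write().format("pmml").save("/tmp/lr-pmml")  # hypothetical output path
    pmml_text = "\n".join(spark.sparkContext.textFile("/tmp/lr-pmml").collect())
    assert "PMML" in pmml_text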
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1) - model = lr.fit(df) - path = tempfile.mkdtemp() - lr_path = path + "/lr-pmml" - model.write().format("pmml").save(lr_path) - pmml_text_list = self.sc.textFile(lr_path).collect() - pmml_text = "\n".join(pmml_text_list) - self.assertIn("Apache Spark", pmml_text) - self.assertIn("PMML", pmml_text) - - def test_logistic_regression(self): - lr = LogisticRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/logreg" - lr.save(lr_path) - lr2 = LogisticRegression.load(lr_path) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LogisticRegression instance uid (%s) " - "did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LogisticRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def _compare_params(self, m1, m2, param): - """ - Compare 2 ML Params instances for the given param, and assert both have the same param value - and parent. The param must be a parameter of m1. - """ - # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. - if m1.isDefined(param): - paramValue1 = m1.getOrDefault(param) - paramValue2 = m2.getOrDefault(m2.getParam(param.name)) - if isinstance(paramValue1, Params): - self._compare_pipelines(paramValue1, paramValue2) - else: - self.assertEqual(paramValue1, paramValue2) # for general types param - # Assert parents are equal - self.assertEqual(param.parent, m2.getParam(param.name).parent) - else: - # If m1 is not defined param, then m2 should not, too. See SPARK-14931. - self.assertFalse(m2.isDefined(m2.getParam(param.name))) - - def _compare_pipelines(self, m1, m2): - """ - Compare 2 ML types, asserting that they are equivalent. 
- This currently supports: - - basic types - - Pipeline, PipelineModel - - OneVsRest, OneVsRestModel - This checks: - - uid - - type - - Param values and parents - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams) or isinstance(m1, Transformer): - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - self._compare_params(m1, m2, p) - elif isinstance(m1, Pipeline): - self.assertEqual(len(m1.getStages()), len(m2.getStages())) - for s1, s2 in zip(m1.getStages(), m2.getStages()): - self._compare_pipelines(s1, s2) - elif isinstance(m1, PipelineModel): - self.assertEqual(len(m1.stages), len(m2.stages)) - for s1, s2 in zip(m1.stages, m2.stages): - self._compare_pipelines(s1, s2) - elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): - for p in m1.params: - self._compare_params(m1, m2, p) - if isinstance(m1, OneVsRestModel): - self.assertEqual(len(m1.models), len(m2.models)) - for x, y in zip(m1.models, m2.models): - self._compare_pipelines(x, y) - else: - raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) - - def test_pipeline_persistence(self): - """ - Pipeline[HashingTF, PCA] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - pl = Pipeline(stages=[tf, pca]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_nested_pipeline_persistence(self): - """ - Pipeline[HashingTF, Pipeline[PCA]] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - p0 = Pipeline(stages=[pca]) - pl = Pipeline(stages=[tf, p0]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_python_transformer_pipeline_persistence(self): - """ - Pipeline[MockUnaryTransformer, Binarizer] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.range(0, 10).toDF('input') - tf = MockUnaryTransformer(shiftVal=2)\ - .setInputCol("input").setOutputCol("shiftedInput") - tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") - pl = Pipeline(stages=[tf, tf2]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - 
try: - rmtree(temp_path) - except OSError: - pass - - def test_onevsrest(self): - temp_path = tempfile.mkdtemp() - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))] * 10, - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - model = ovr.fit(df) - ovrPath = temp_path + "/ovr" - ovr.save(ovrPath) - loadedOvr = OneVsRest.load(ovrPath) - self._compare_pipelines(ovr, loadedOvr) - modelPath = temp_path + "/ovrModel" - model.save(modelPath) - loadedModel = OneVsRestModel.load(modelPath) - self._compare_pipelines(model, loadedModel) - - def test_decisiontree_classifier(self): - dt = DecisionTreeClassifier(maxDepth=1) - path = tempfile.mkdtemp() - dtc_path = path + "/dtc" - dt.save(dtc_path) - dt2 = DecisionTreeClassifier.load(dtc_path) - self.assertEqual(dt2.uid, dt2.maxDepth.parent, - "Loaded DecisionTreeClassifier instance uid (%s) " - "did not match Param's uid (%s)" - % (dt2.uid, dt2.maxDepth.parent)) - self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], - "Loaded DecisionTreeClassifier instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_decisiontree_regressor(self): - dt = DecisionTreeRegressor(maxDepth=1) - path = tempfile.mkdtemp() - dtr_path = path + "/dtr" - dt.save(dtr_path) - dt2 = DecisionTreeClassifier.load(dtr_path) - self.assertEqual(dt2.uid, dt2.maxDepth.parent, - "Loaded DecisionTreeRegressor instance uid (%s) " - "did not match Param's uid (%s)" - % (dt2.uid, dt2.maxDepth.parent)) - self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], - "Loaded DecisionTreeRegressor instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_default_read_write(self): - temp_path = tempfile.mkdtemp() - - lr = LogisticRegression() - lr.setMaxIter(50) - lr.setThreshold(.75) - writer = DefaultParamsWriter(lr) - - savePath = temp_path + "/lr" - writer.save(savePath) - - reader = DefaultParamsReadable.read() - lr2 = reader.load(savePath) - - self.assertEqual(lr.uid, lr2.uid) - self.assertEqual(lr.extractParamMap(), lr2.extractParamMap()) - - # test overwrite - lr.setThreshold(.8) - writer.overwrite().save(savePath) - - reader = DefaultParamsReadable.read() - lr3 = reader.load(savePath) - - self.assertEqual(lr.uid, lr3.uid) - self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) - - def test_default_read_write_default_params(self): - lr = LogisticRegression() - self.assertFalse(lr.isSet(lr.getParam("threshold"))) - - lr.setMaxIter(50) - lr.setThreshold(.75) - - # `threshold` is set by user, default param `predictionCol` is not set by user. 
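The assertions that follow rely on the split between explicitly set params and default params; a minimal sketch of that distinction on its own, assuming an active local SparkSession (required to construct the estimator):

    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LogisticRegression

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    lr = LogisticRegression()
    lr.setMaxIter(50)
    lr.setThreshold(.75)
    # Explicitly set values live in the param map; untouched params are backed
    # only by the default param map.
    assert lr.isSet(lr.getParam("threshold"))
    assert not lr.isSet(lr.getParam("predictionCol"))
    assert lr.hasDefault(lr.getParam("predictionCol"))
    assert lr.getOrDefault(lr.getParam("predictionCol")) == "prediction"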
- self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - writer = DefaultParamsWriter(lr) - metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) - self.assertTrue("defaultParamMap" in metadata) - - reader = DefaultParamsReadable.read() - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - # manually create metadata without `defaultParamMap` section. - del metadata['defaultParamMap'] - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): - reader.getAndSetParams(lr, loadedMetadata) - - # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. - metadata['sparkVersion'] = '2.3.0' - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - -class LDATest(SparkSessionTestCase): - - def _compare(self, m1, m2): - """ - Temp method for comparing instances. - TODO: Replace with generic implementation once SPARK-14706 is merged. - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - if m1.isDefined(p): - self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) - self.assertEqual(p.parent, m2.getParam(p.name).parent) - if isinstance(m1, LDAModel): - self.assertEqual(m1.vocabSize(), m2.vocabSize()) - self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) - - def test_persistence(self): - # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. 
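A condensed sketch of the same save/load round trip, assuming a local SparkSession and a temporary directory (illustrative only):

    import tempfile
    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.clustering import LDA, LocalLDAModel, DistributedLDAModel

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    corpus = spark.createDataFrame(
        [(1, Vectors.dense([0.0, 1.0])),
         (2, Vectors.sparse(2, {0: 1.0}))],
        ["id", "features"])
    # The "em" optimizer produces a DistributedLDAModel; toLocal() converts it
    # to a LocalLDAModel, and both flavors round-trip through save()/load().
    dist_model = LDA(k=2, seed=1, optimizer="em").fit(corpus)
    local_model = dist_model.toLocal()
    path = tempfile.mkdtemp()
    dist_model.save(path + "/distLDAModel")
    local_model.save(path + "/localLDAModel")
    assert DistributedLDAModel.load(path + "/distLDAModel").isDistributed()
    assert not LocalLDAModel.load(path + "/localLDAModel").isDistributed()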
- df = self.spark.createDataFrame([ - [1, Vectors.dense([0.0, 1.0])], - [2, Vectors.sparse(2, {0: 1.0})], - ], ["id", "features"]) - # Fit model - lda = LDA(k=2, seed=1, optimizer="em") - distributedModel = lda.fit(df) - self.assertTrue(distributedModel.isDistributed()) - localModel = distributedModel.toLocal() - self.assertFalse(localModel.isDistributed()) - # Define paths - path = tempfile.mkdtemp() - lda_path = path + "/lda" - dist_model_path = path + "/distLDAModel" - local_model_path = path + "/localLDAModel" - # Test LDA - lda.save(lda_path) - lda2 = LDA.load(lda_path) - self._compare(lda, lda2) - # Test DistributedLDAModel - distributedModel.save(dist_model_path) - distributedModel2 = DistributedLDAModel.load(dist_model_path) - self._compare(distributedModel, distributedModel2) - # Test LocalLDAModel - localModel.save(local_model_path) - localModel2 = LocalLDAModel.load(local_model_path) - self._compare(localModel, localModel2) - # Clean up - try: - rmtree(path) - except OSError: - pass - - -class TrainingSummaryTest(SparkSessionTestCase): - - def test_linear_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertAlmostEqual(s.explainedVariance, 0.25, 2) - self.assertAlmostEqual(s.meanAbsoluteError, 0.0) - self.assertAlmostEqual(s.meanSquaredError, 0.0) - self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) - self.assertAlmostEqual(s.r2, 1.0, 2) - self.assertAlmostEqual(s.r2adj, 1.0, 2) - self.assertTrue(isinstance(s.residuals, DataFrame)) - self.assertEqual(s.numInstances, 2) - self.assertEqual(s.degreesOfFreedom, 1) - devResiduals = s.devianceResiduals - self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class LinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) - - def test_glr_summary(self): - from pyspark.ml.linalg import Vectors - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", - fitIntercept=False) - model = glr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertEqual(s.numIterations, 1) # this should 
default to a single iteration of WLS - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.numInstances, 2) - self.assertTrue(isinstance(s.residuals(), DataFrame)) - self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - self.assertEqual(s.degreesOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedomNull, 2) - self.assertEqual(s.rank, 1) - self.assertTrue(isinstance(s.solver, basestring)) - self.assertTrue(isinstance(s.aic, float)) - self.assertTrue(isinstance(s.deviance, float)) - self.assertTrue(isinstance(s.nullDeviance, float)) - self.assertTrue(isinstance(s.dispersion, float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class GeneralizedLinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.deviance, s.deviance) - - def test_binary_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertTrue(isinstance(s.roc, DataFrame)) - self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) - self.assertTrue(isinstance(s.pr, DataFrame)) - self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) - self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) - self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) - self.assertAlmostEqual(s.accuracy, 1.0, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) - self.assertAlmostEqual(s.weightedRecall, 1.0, 2) - self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - 
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) - - def test_multiclass_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], [])), - (2.0, 2.0, Vectors.dense(2.0)), - (2.0, 2.0, Vectors.dense(1.9))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertAlmostEqual(s.accuracy, 0.75, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) - self.assertAlmostEqual(s.weightedRecall, 0.75, 2) - self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) - - def test_gaussian_mixture_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - gmm = GaussianMixture(k=2) - model = gmm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertTrue(isinstance(s.probability, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 3) - - def test_bisecting_kmeans_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - bkm = BisectingKMeans(k=2) - model = bkm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 20) - - def test_kmeans_summary(self): - data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), - (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] - df = 
self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=2, seed=1) - model = kmeans.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 1) - - -class KMeansTests(SparkSessionTestCase): - - def test_kmeans_cosine_distance(self): - data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), - (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), - (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] - df = self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") - model = kmeans.fit(df) - result = model.transform(df).collect() - self.assertTrue(result[0].prediction == result[1].prediction) - self.assertTrue(result[2].prediction == result[3].prediction) - self.assertTrue(result[4].prediction == result[5].prediction) - - -class OneVsRestTests(SparkSessionTestCase): - - def test_copy(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - ovr1 = ovr.copy({lr.maxIter: 10}) - self.assertEqual(ovr.getClassifier().getMaxIter(), 5) - self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) - model = ovr.fit(df) - model1 = model.copy({model.predictionCol: "indexed"}) - self.assertEqual(model1.getPredictionCol(), "indexed") - - def test_output_columns(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, parallelism=1) - model = ovr.fit(df) - output = model.transform(df) - self.assertEqual(output.columns, ["label", "features", "prediction"]) - - def test_parallelism_doesnt_change_output(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1) - modelPar1 = ovrPar1.fit(df) - ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2) - modelPar2 = ovrPar2.fit(df) - for i, model in enumerate(modelPar1.models): - self.assertTrue(np.allclose(model.coefficients.toArray(), - modelPar2.models[i].coefficients.toArray(), atol=1E-4)) - self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4)) - - def test_support_for_weightCol(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), - (1.0, Vectors.sparse(2, [], []), 1.0), - (2.0, Vectors.dense(0.5, 0.5), 1.0)], - ["label", "features", "weight"]) - # classifier inherits hasWeightCol - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, weightCol="weight") - self.assertIsNotNone(ovr.fit(df)) - # classifier doesn't inherit hasWeightCol - dt = DecisionTreeClassifier() - ovr2 = OneVsRest(classifier=dt, weightCol="weight") - self.assertIsNotNone(ovr2.fit(df)) - - -class HashingTFTest(SparkSessionTestCase): - - def test_apply_binary_term_freqs(self): - - df = 
self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) - n = 10 - hashingTF = HashingTF() - hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) - output = hashingTF.transform(df) - features = output.select("features").first().features.toArray() - expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray() - for i in range(0, n): - self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(features[i])) - - -class GeneralizedLinearRegressionTest(SparkSessionTestCase): - - def test_tweedie_distribution(self): - - df = self.spark.createDataFrame( - [(1.0, Vectors.dense(0.0, 0.0)), - (1.0, Vectors.dense(1.0, 2.0)), - (2.0, Vectors.dense(0.0, 0.0)), - (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) - - glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) - - model2 = glr.setLinkPower(-1.0).fit(df) - self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) - self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) - - def test_offset(self): - - df = self.spark.createDataFrame( - [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), - (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), - (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), - (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) - - glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], - atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) - - -class LinearRegressionTest(SparkSessionTestCase): - - def test_linear_regression_with_huber_loss(self): - - data_path = "data/mllib/sample_linear_regression_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lir = LinearRegression(loss="huber", epsilon=2.0) - model = lir.fit(df) - - expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, - 1.2612, -0.333, -0.5694, -0.6311, 0.6053] - expectedIntercept = 0.1607 - expectedScale = 9.758 - - self.assertTrue( - np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) - self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) - self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) - - -class LogisticRegressionTest(SparkSessionTestCase): - - def test_binomial_logistic_regression_with_bound(self): - - df = self.spark.createDataFrame( - [(1.0, 1.0, Vectors.dense(0.0, 5.0)), - (0.0, 2.0, Vectors.dense(1.0, 2.0)), - (1.0, 3.0, Vectors.dense(2.0, 1.0)), - (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"]) - - lor = LogisticRegression(regParam=0.01, weightCol="weight", - lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]), - upperBoundsOnIntercepts=Vectors.dense(0.0)) - model = lor.fit(df) - self.assertTrue( - np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4)) - - def test_multinomial_logistic_regression_with_bound(self): - - data_path = "data/mllib/sample_multiclass_classification_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lor = LogisticRegression(regParam=0.01, - 
lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)), - upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0)) - model = lor.fit(df) - expected = [[4.593, 4.5516, 9.0099, 12.2904], - [1.0, 8.1093, 7.0, 10.0], - [3.041, 5.0, 8.0, 11.0]] - for i in range(0, len(expected)): - self.assertTrue( - np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4)) - self.assertTrue( - np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4)) - - -class MultilayerPerceptronClassifierTest(SparkSessionTestCase): - - def test_raw_and_probability_prediction(self): - - data_path = "data/mllib/sample_multiclass_classification_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3], - blockSize=128, seed=123) - model = mlp.fit(df) - test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF() - result = model.transform(test).head() - expected_prediction = 2.0 - expected_probability = [0.0, 0.0, 1.0] - expected_rawPrediction = [57.3955, -124.5462, 67.9943] - self.assertTrue(result.prediction, expected_prediction) - self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4)) - self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4)) - - -class FPGrowthTests(SparkSessionTestCase): - def setUp(self): - super(FPGrowthTests, self).setUp() - self.data = self.spark.createDataFrame( - [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], - ["items"]) - - def test_association_rules(self): - fp = FPGrowth() - fpm = fp.fit(self.data) - - expected_association_rules = self.spark.createDataFrame( - [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], - ["antecedent", "consequent", "confidence", "lift"] - ) - actual_association_rules = fpm.associationRules - - self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) - self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) - - def test_freq_itemsets(self): - fp = FPGrowth() - fpm = fp.fit(self.data) - - expected_freq_itemsets = self.spark.createDataFrame( - [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], - ["items", "freq"] - ) - actual_freq_itemsets = fpm.freqItemsets - - self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) - self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) - - def tearDown(self): - del self.data - - -class ImageReaderTest(SparkSessionTestCase): - - def test_read_images(self): - data_path = 'data/mllib/images/origin/kittens' - df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - self.assertEqual(df.count(), 4) - first_row = df.take(1)[0][0] - array = ImageSchema.toNDArray(first_row) - self.assertEqual(len(array), first_row[1]) - self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) - self.assertEqual(df.schema, ImageSchema.imageSchema) - self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) - expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} - self.assertEqual(ImageSchema.ocvTypes, expected) - expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] - self.assertEqual(ImageSchema.imageFields, expected) - self.assertEqual(ImageSchema.undefinedImageType, "Undefined") - - with QuietTest(self.sc): - self.assertRaisesRegexp( - TypeError, - "image argument should be pyspark.sql.types.Row; however", - lambda: 
ImageSchema.toNDArray("a")) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - ValueError, - "image argument should have attributes specified in", - lambda: ImageSchema.toNDArray(Row(a=1))) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - TypeError, - "array argument should be numpy.ndarray; however, it got", - lambda: ImageSchema.toImage("a")) - - -class ImageReaderTest2(PySparkTestCase): - - @classmethod - def setUpClass(cls): - super(ImageReaderTest2, cls).setUpClass() - cls.hive_available = True - # Note that here we enable Hive's support. - cls.spark = None - try: - cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.tearDownClass() - cls.hive_available = False - except TypeError: - cls.tearDownClass() - cls.hive_available = False - if cls.hive_available: - cls.spark = HiveContext._createForTesting(cls.sc) - - def setUp(self): - if not self.hive_available: - self.skipTest("Hive is not available.") - - @classmethod - def tearDownClass(cls): - super(ImageReaderTest2, cls).tearDownClass() - if cls.spark is not None: - cls.spark.sparkSession.stop() - cls.spark = None - - def test_read_images_multiple_times(self): - # This test case is to check if `ImageSchema.readImages` tries to - # initiate Hive client multiple times. See SPARK-22651. - data_path = 'data/mllib/images/origin/kittens' - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - - -class ALSTest(SparkSessionTestCase): - - def test_storage_levels(self): - df = self.spark.createDataFrame( - [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], - ["user", "item", "rating"]) - als = ALS().setMaxIter(1).setRank(1) - # test default params - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK") - # test non-default params - als.setIntermediateStorageLevel("MEMORY_ONLY_2") - als.setFinalStorageLevel("DISK_ONLY") - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") - - -class DefaultValuesTests(PySparkTestCase): - """ - Test :py:class:`JavaParams` classes to see if their default Param values match - those in their Scala counterparts. 
- """ - - def test_java_params(self): - import pyspark.ml.feature - import pyspark.ml.classification - import pyspark.ml.clustering - import pyspark.ml.evaluation - import pyspark.ml.pipeline - import pyspark.ml.recommendation - import pyspark.ml.regression - - modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, - pyspark.ml.regression] - for module in modules: - for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and not name.endswith('Params')\ - and issubclass(cls, JavaParams) and not inspect.isabstract(cls): - # NOTE: disable check_params_exist until there is parity with Scala API - ParamTests.check_params(self, cls(), check_params_exist=False) - - # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel - ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), - check_params_exist=False) - ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), - check_params_exist=False) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - 
self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) 
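# --- Editor's sketch (not part of the deleted file): where the 15.0 expected above comes from.
# SparseVector(4, {1: 1, 3: 2}) densifies to (0, 1, 0, 2); against DenseVector([1, 2, 3, 4]) the
# squared distance is (0-1)^2 + (1-2)^2 + (0-3)^2 + (2-4)^2 = 1 + 1 + 9 + 4 = 15.
# Runnable with pyspark.ml.linalg alone, no SparkContext required:
from pyspark.ml.linalg import DenseVector, SparseVector
sv = SparseVector(4, {1: 1, 3: 2})
dv = DenseVector([1.0, 2.0, 3.0, 4.0])
assert sv.squared_distance(dv) == 15.0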
- for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. 
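# --- Editor's sketch (not part of the deleted file): reading the CSC layout used in the
# sparse-matrix test above. For SparseMatrix(3, 4, colPtrs=[0, 2, 2, 4, 4],
# rowIndices=[1, 2, 1, 2], values=[1, 2, 4, 5]), colPtrs says column 0 owns stored entries
# 0..1 and column 2 owns entries 2..3 (columns 1 and 3 are empty), so the non-zeros are
# (1,0)=1.0, (2,0)=2.0, (1,2)=4.0, (2,2)=5.0; that is exactly the `expected` dense layout
# checked by the indexing loop.
from pyspark.ml.linalg import SparseMatrix
sm = SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
print(sm.toArray())  # dense 3x4 array: rows [0,0,0,0], [1,0,4,0], [2,0,5,0]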
- smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), - Row(label=0.0, features=self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", 
self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -class WrapperTests(MLlibTestCase): - - def test_new_java_array(self): - # test array of strings - str_list = ["a", "b", "c"] - java_class = self.sc._gateway.jvm.java.lang.String - java_array = JavaWrapper._new_java_array(str_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), str_list) - # test array of integers - int_list = [1, 2, 3] - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array(int_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), int_list) - # test array of floats - float_list = [0.1, 0.2, 0.3] - java_class = self.sc._gateway.jvm.java.lang.Double - java_array = JavaWrapper._new_java_array(float_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), float_list) - # test array of bools - bool_list = [False, True, True] - java_class = self.sc._gateway.jvm.java.lang.Boolean - java_array = JavaWrapper._new_java_array(bool_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), bool_list) - # test array of Java DenseVectors - v1 = DenseVector([0.0, 1.0]) - v2 = DenseVector([1.0, 0.0]) - vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] - java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector - java_array = JavaWrapper._new_java_array(vec_java_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) - # test empty array - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array([], java_class) - self.assertEqual(_java2py(self.sc, java_array), []) - - -class ChiSquareTestTests(SparkSessionTestCase): - - def test_chisquaretest(self): - data = [[0, Vectors.dense([0, 1, 2])], - [1, Vectors.dense([1, 1, 1])], - [2, Vectors.dense([2, 1, 0])]] - df = self.spark.createDataFrame(data, ['label', 'feat']) - res = ChiSquareTest.test(df, 'feat', 'label') - # This line is hitting the collect bug described in #17218, commented for now. 
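# --- Editor's sketch (not part of the deleted file): standalone ChiSquareTest usage with the
# same toy data as the test above; the result is a one-row DataFrame whose columns are the
# three fields asserted below (pValues, degreesOfFreedom, statistics). The master/appName
# values are illustrative choices only.
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import ChiSquareTest

spark = SparkSession.builder.master("local[2]").appName("chisq-sketch").getOrCreate()
df = spark.createDataFrame(
    [(0, Vectors.dense([0, 1, 2])),
     (1, Vectors.dense([1, 1, 1])),
     (2, Vectors.dense([2, 1, 0]))],
    ["label", "feat"])
ChiSquareTest.test(df, "feat", "label").show(truncate=False)
spark.stop()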
- # pValues = res.select("degreesOfFreedom").collect()) - self.assertIsInstance(res, DataFrame) - fieldNames = set(field.name for field in res.schema.fields) - expectedFields = ["pValues", "degreesOfFreedom", "statistics"] - self.assertTrue(all(field in fieldNames for field in expectedFields)) - - -class UnaryTransformerTests(SparkSessionTestCase): - - def test_unary_transformer_validate_input_type(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - # should not raise any errors - transformer.validateInputType(DoubleType()) - - with self.assertRaises(TypeError): - # passing the wrong input type should raise an error - transformer.validateInputType(IntegerType()) - - def test_unary_transformer_transform(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - df = self.spark.range(0, 10).toDF('input') - df = df.withColumn("input", df.input.cast(dataType="double")) - - transformed_df = transformer.transform(df) - results = transformed_df.select("input", "output").collect() - - for res in results: - self.assertEqual(res.input + shiftVal, res.output) - - -class EstimatorTest(unittest.TestCase): - - def testDefaultFitMultiple(self): - N = 4 - data = MockDataset() - estimator = MockEstimator() - params = [{estimator.fake: i} for i in range(N)] - modelIter = estimator.fitMultiple(data, params) - indexList = [] - for index, model in modelIter: - self.assertEqual(model.getFake(), index) - indexList.append(index) - self.assertEqual(sorted(indexList), list(range(N))) - - -if __name__ == "__main__": - from pyspark.ml.tests import * - - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.ml_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py deleted file mode 100644 index 653f5cb9ff4a2..0000000000000 --- a/python/pyspark/mllib/tests.py +++ /dev/null @@ -1,1788 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Fuller unit tests for Python MLlib. 
-""" - -import os -import sys -import tempfile -import array as pyarray -from math import sqrt -from time import time, sleep -from shutil import rmtree - -import unishark -from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) -from numpy import sum as array_sum - -from py4j.protocol import Py4JJavaError - -if sys.version > '3': - basestring = str - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark import SparkContext -import pyspark.ml.linalg as newlinalg -from pyspark.mllib.common import _to_java_object_rdd -from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT -from pyspark.mllib.linalg.distributed import RowMatrix -from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD -from pyspark.mllib.fpm import FPGrowth -from pyspark.mllib.recommendation import Rating -from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD -from pyspark.mllib.random import RandomRDDs -from pyspark.mllib.stat import Statistics -from pyspark.mllib.feature import HashingTF -from pyspark.mllib.feature import Word2Vec -from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler, ElementwiseProduct -from pyspark.mllib.util import LinearDataGenerator -from pyspark.mllib.util import MLUtils -from pyspark.serializers import PickleSerializer -from pyspark.streaming import StreamingContext -from pyspark.sql import SparkSession -from pyspark.sql.utils import IllegalArgumentException -from pyspark.streaming import StreamingContext - -_have_scipy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class MLLibStreamingTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.ssc = StreamingContext(self.sc, 1.0) - - def tearDown(self): - self.ssc.stop(False) - self.sc.stop() - - @staticmethod - def _eventually(condition, timeout=30.0, catch_assertions=False): - """ - Wait a given amount of time for a condition to pass, else fail with an error. - This is a helper utility for streaming ML tests. - :param condition: Function that checks for termination conditions. - condition() can return: - - True: Conditions met. Return without error. - - other value: Conditions not met yet. Continue. Upon timeout, - include last such value in error message. - Note that this method may be called at any time during - streaming execution (e.g., even before any results - have been created). - :param timeout: Number of seconds to wait. Default 30 seconds. - :param catch_assertions: If False (default), do not catch AssertionErrors. - If True, catch AssertionErrors; continue, but save - error to throw upon timeout. 
- """ - start_time = time() - lastValue = None - while time() - start_time < timeout: - if catch_assertions: - try: - lastValue = condition() - except AssertionError as e: - lastValue = e - else: - lastValue = condition() - if lastValue is True: - return - sleep(0.01) - if isinstance(lastValue, AssertionError): - raise lastValue - else: - raise AssertionError( - "Test failed due to timeout after %g sec, with last condition returning: %s" - % (timeout, lastValue)) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def 
test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) - for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. - smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_parse_vector(self): - a = DenseVector([]) - self.assertEqual(str(a), '[]') - self.assertEqual(Vectors.parse(str(a)), a) - a = DenseVector([3, 4, 6, 7]) - self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(4, [], []) - self.assertEqual(str(a), '(4,[],[])') - self.assertEqual(SparseVector.parse(str(a)), a) - a = SparseVector(4, [0, 2], [3, 4]) - self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(10, [0, 1], [4, 5]) - self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - 
self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - def test_ml_mllib_vector_conversion(self): - # to ml - # dense - mllibDV = Vectors.dense([1, 2, 3]) - mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) - mlDV2 = mllibDV.asML() - self.assertEqual(mlDV2, mlDV1) - # sparse - mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV2 = mllibSV.asML() - self.assertEqual(mlSV2, mlSV1) - # from ml - # dense - mllibDV1 = Vectors.dense([1, 2, 3]) - mlDV = newlinalg.Vectors.dense([1, 2, 3]) - mllibDV2 = Vectors.fromML(mlDV) - self.assertEqual(mllibDV1, mllibDV2) - # sparse - mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mllibSV2 = Vectors.fromML(mlSV) - self.assertEqual(mllibSV1, mllibSV2) - - def test_ml_mllib_matrix_conversion(self): - # to ml - # dense - mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM2 = mllibDM.asML() - self.assertEqual(mlDM2, mlDM1) - # transposed - mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt2 = mllibDMt.asML() - self.assertEqual(mlDMt2, mlDMt1) - # sparse - mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM2 = mllibSM.asML() - self.assertEqual(mlSM2, mlSM1) - # transposed - mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt2 = mllibSMt.asML() - self.assertEqual(mlSMt2, mlSMt1) - # from ml - # dense - mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) - mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) - mllibDM2 = Matrices.fromML(mlDM) - self.assertEqual(mllibDM1, mllibDM2) - # transposed - mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True) - mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) - mllibDMt2 = Matrices.fromML(mlDMt) - self.assertEqual(mllibDMt1, mllibDMt2) - # sparse - mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mllibSM2 = Matrices.fromML(mlSM) - self.assertEqual(mllibSM1, mllibSM2) - # transposed - mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mllibSMt2 = Matrices.fromML(mlSMt) - self.assertEqual(mllibSMt1, mllibSMt2) - - -class ListTests(MLlibTestCase): - - """ - Test MLlib algorithms on plain lists, to make sure they're passed through - as NumPy arrays. 
- """ - - def test_bisecting_kmeans(self): - from pyspark.mllib.clustering import BisectingKMeans - data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) - bskm = BisectingKMeans() - model = bskm.train(self.sc.parallelize(data, 2), k=4) - p = array([0.0, 0.0]) - rdd_p = self.sc.parallelize([p]) - self.assertEqual(model.predict(p), model.predict(rdd_p).first()) - self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) - self.assertEqual(model.k, len(model.clusterCenters)) - - def test_kmeans(self): - from pyspark.mllib.clustering import KMeans - data = [ - [0, 1.1], - [0, 1.2], - [1.1, 0], - [1.2, 0], - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", - initializationSteps=7, epsilon=1e-4) - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_kmeans_deterministic(self): - from pyspark.mllib.clustering import KMeans - X = range(0, 100, 10) - Y = range(0, 100, 10) - data = [[x, y] for x, y in zip(X, Y)] - clusters1 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - clusters2 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - centers1 = clusters1.centers - centers2 = clusters2.centers - for c1, c2 in zip(centers1, centers2): - # TODO: Allow small numeric difference. - self.assertTrue(array_equal(c1, c2)) - - def test_gmm(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - [1, 2], - [8, 9], - [-4, -3], - [-6, -7], - ]) - clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=1) - labels = clusters.predict(data).collect() - self.assertEqual(labels[0], labels[1]) - self.assertEqual(labels[2], labels[3]) - - def test_gmm_deterministic(self): - from pyspark.mllib.clustering import GaussianMixture - x = range(0, 100, 10) - y = range(0, 100, 10) - data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) - clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEqual(round(c1, 7), round(c2, 7)) - - def test_gmm_with_initial_model(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - (-10, -5), (-9, -4), (10, 5), (9, 4) - ]) - - gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63) - gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63, initialModel=gmm1) - self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ - RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel - data = [ - LabeledPoint(0.0, [1, 0, 0]), - LabeledPoint(1.0, [0, 1, 1]), - LabeledPoint(0.0, [2, 0, 0]), - LabeledPoint(1.0, [0, 2, 1]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - temp_dir = tempfile.mkdtemp() - - lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - 
self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd, iterations=10) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - dt_model_dir = os.path.join(temp_dir, "dt") - dt_model.save(self.sc, dt_model_dir) - same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) - self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) - - rf_model = RandomForest.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, - maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - rf_model_dir = os.path.join(temp_dir, "rf") - rf_model.save(self.sc, rf_model_dir) - same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) - self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) - - gbt_model = GradientBoostedTrees.trainClassifier( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - gbt_model_dir = os.path.join(temp_dir, "gbt") - gbt_model.save(self.sc, gbt_model_dir) - same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) - self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) - - try: - rmtree(temp_dir) - except OSError: - pass - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees - data = [ - LabeledPoint(-1.0, [0, -1]), - LabeledPoint(1.0, [0, 1]), - LabeledPoint(-1.0, [0, -2]), - LabeledPoint(1.0, [0, 2]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd, iterations=10) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = 
RidgeRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - gbt_model = GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - try: - LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - except ValueError: - self.fail() - - # Verify that maxBins is being passed through - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) - with self.assertRaises(Exception) as cm: - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) - - -class StatTests(MLlibTestCase): - # SPARK-4023 - def test_col_with_different_rdds(self): - # numpy - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(1000, summary.count()) - # array - data = self.sc.parallelize([range(10)] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - # array - data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - - def test_col_norms(self): - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(10, len(summary.normL1())) - self.assertEqual(10, len(summary.normL2())) - - data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) - summary2 = Statistics.colStats(data2) - self.assertEqual(array([45.0]), summary2.normL1()) - import math - expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) - self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = 
self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(MLlibTestCase): - - """ - Test both vector operations and MLlib algorithms with SciPy sparse matrices, - if SciPy is available. 
- """ - - def test_serialize(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(lil)) - self.assertEqual(sv, _convert_to_vector(lil.tocsc())) - self.assertEqual(sv, _convert_to_vector(lil.tocoo())) - self.assertEqual(sv, _convert_to_vector(lil.tocsr())) - self.assertEqual(sv, _convert_to_vector(lil.todok())) - - def serialize(l): - return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEqual(sv, serialize(lil)) - self.assertEqual(sv, serialize(lil.tocsc())) - self.assertEqual(sv, serialize(lil.tocsr())) - self.assertEqual(sv, serialize(lil.todok())) - - def test_convert_to_vector(self): - from scipy.sparse import csc_matrix - # Create a CSC matrix with non-sorted indices - indptr = array([0, 2]) - indices = array([3, 1]) - data = array([2.0, 1.0]) - csc = csc_matrix((data, indices, indptr)) - self.assertFalse(csc.has_sorted_indices) - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(csc)) - - def test_dot(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEqual(10.0, dv.dot(lil)) - - def test_squared_distance(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 3 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEqual(15.0, dv.squared_distance(lil)) - self.assertEqual(15.0, sv.squared_distance(lil)) - - def scipy_matrix(self, size, values): - """Create a column SciPy matrix from a dictionary of values""" - from scipy.sparse import lil_matrix - lil = lil_matrix((size, 1)) - for key, value in values.items(): - lil[key, 0] = value - return lil - - def test_clustering(self): - from pyspark.mllib.clustering import KMeans - data = [ - self.scipy_matrix(3, {1: 1.0}), - self.scipy_matrix(3, {1: 1.1}), - self.scipy_matrix(3, {2: 1.0}), - self.scipy_matrix(3, {2: 1.1}) - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LogisticRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 
0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = RidgeRegressionWithSGD.train(rdd) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - -class ChiSqTestTests(MLlibTestCase): - def test_goodness_of_fit(self): - from numpy import inf - - observed = Vectors.dense([4, 6, 5]) - pearson = Statistics.chiSqTest(observed) - - # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` - self.assertEqual(pearson.statistic, 0.4) - self.assertEqual(pearson.degreesOfFreedom, 2) - self.assertAlmostEqual(pearson.pValue, 0.8187, 4) - - # Different expected and observed sum - observed1 = Vectors.dense([21, 38, 43, 80]) - expected1 = Vectors.dense([3, 5, 7, 20]) - pearson1 = Statistics.chiSqTest(observed1, expected1) - - # Results validated against the R command - # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` - self.assertAlmostEqual(pearson1.statistic, 14.1429, 4) - self.assertEqual(pearson1.degreesOfFreedom, 3) - self.assertAlmostEqual(pearson1.pValue, 0.002717, 4) - - # Vectors with different sizes - observed3 = Vectors.dense([1.0, 2.0, 3.0]) - expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0]) - self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3) - - # Negative counts in observed - neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) - - # Count = 0.0 in expected but not observed - zero_expected = Vectors.dense([1.0, 0.0, 3.0]) - pearson_inf = Statistics.chiSqTest(observed, zero_expected) - self.assertEqual(pearson_inf.statistic, inf) - self.assertEqual(pearson_inf.degreesOfFreedom, 2) - 
self.assertEqual(pearson_inf.pValue, 0.0) - - # 0.0 in expected and observed simultaneously - zero_observed = Vectors.dense([2.0, 0.0, 1.0]) - self.assertRaises( - IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) - - def test_matrix_independence(self): - data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] - chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - - # Results validated against R command - # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` - self.assertAlmostEqual(chi.statistic, 21.9958, 4) - self.assertEqual(chi.degreesOfFreedom, 6) - self.assertAlmostEqual(chi.pValue, 0.001213, 4) - - # Negative counts - neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) - - # Row sum = 0.0 - row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) - - # Column sum = 0.0 - col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) - - def test_chi_sq_pearson(self): - data = [ - LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), - LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), - LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), - LabeledPoint(1.0, Vectors.dense([3.5, 40.0])) - ] - - for numParts in [2, 4, 6, 8]: - chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts)) - feature1 = chi[0] - self.assertEqual(feature1.statistic, 0.75) - self.assertEqual(feature1.degreesOfFreedom, 2) - self.assertAlmostEqual(feature1.pValue, 0.6873, 4) - - feature2 = chi[1] - self.assertEqual(feature2.statistic, 1.5) - self.assertEqual(feature2.degreesOfFreedom, 3) - self.assertAlmostEqual(feature2.pValue, 0.6823, 4) - - def test_right_number_of_results(self): - num_cols = 1001 - sparse_data = [ - LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])), - LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)])) - ] - chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data)) - self.assertEqual(len(chi), num_cols) - self.assertIsNotNone(chi[1000]) - - -class KolmogorovSmirnovTest(MLlibTestCase): - - def test_R_implementation_equivalence(self): - data = self.sc.parallelize([ - 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, - -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, - -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, - -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, - 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 - ]) - model = Statistics.kolmogorovSmirnovTest(data, "norm") - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - -class SerDeTest(MLlibTestCase): - def test_to_java_object_rdd(self): # SPARK-6660 - data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) - self.assertEqual(_to_java_object_rdd(data).count(), 10) - - -class FeatureTest(MLlibTestCase): - def test_idf_model(self): - data = [ - Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), - Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), - Vectors.dense([1, 4, 1, 0, 0, 4, 9, 
0, 1, 2, 0]), - Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) - ] - model = IDF().fit(self.sc.parallelize(data, 2)) - idf = model.idf() - self.assertEqual(len(idf), 11) - - -class Word2VecTests(MLlibTestCase): - def test_word2vec_setters(self): - model = Word2Vec() \ - .setVectorSize(2) \ - .setLearningRate(0.01) \ - .setNumPartitions(2) \ - .setNumIterations(10) \ - .setSeed(1024) \ - .setMinCount(3) \ - .setWindowSize(6) - self.assertEqual(model.vectorSize, 2) - self.assertTrue(model.learningRate < 0.02) - self.assertEqual(model.numPartitions, 2) - self.assertEqual(model.numIterations, 10) - self.assertEqual(model.seed, 1024) - self.assertEqual(model.minCount, 3) - self.assertEqual(model.windowSize, 6) - - def test_word2vec_get_vectors(self): - data = [ - ["a", "b", "c", "d", "e", "f", "g"], - ["a", "b", "c", "d", "e", "f"], - ["a", "b", "c", "d", "e"], - ["a", "b", "c", "d"], - ["a", "b", "c"], - ["a", "b"], - ["a"] - ] - model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEqual(len(model.getVectors()), 3) - - -class StandardScalerTests(MLlibTestCase): - def test_model_setters(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertIsNotNone(model.setWithMean(True)) - self.assertIsNotNone(model.setWithStd(True)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) - - def test_model_transform(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) - - -class ElementwiseProductTests(MLlibTestCase): - def test_model_transform(self): - weight = Vectors.dense([3, 2, 1]) - - densevec = Vectors.dense([4, 5, 6]) - sparsevec = Vectors.sparse(3, [0], [1]) - eprod = ElementwiseProduct(weight) - self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) - self.assertEqual( - eprod.transform(sparsevec), SparseVector(3, [0], [3])) - - -class StreamingKMeansTest(MLLibStreamingTestCase): - def test_model_params(self): - """Test that the model params are set correctly""" - stkm = StreamingKMeans() - stkm.setK(5).setDecayFactor(0.0) - self.assertEqual(stkm._k, 5) - self.assertEqual(stkm._decayFactor, 0.0) - - # Model not set yet. 
- self.assertIsNone(stkm.latestModel()) - self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) - - stkm.setInitialCenters( - centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEqual( - stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) - - def test_accuracy_for_single_center(self): - """Test that parameters obtained are correct for a single center.""" - centers, batches = self.streamingKMeansDataGenerator( - batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) - stkm = StreamingKMeans(1) - stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) - input_stream = self.ssc.queueStream( - [self.sc.parallelize(batch, 1) for batch in batches]) - stkm.trainOn(input_stream) - - self.ssc.start() - - def condition(): - self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) - return True - self._eventually(condition, catch_assertions=True) - - realCenters = array_sum(array(centers), axis=0) - for i in range(5): - modelCenters = stkm.latestModel().centers[0][i] - self.assertAlmostEqual(centers[0][i], modelCenters, 1) - self.assertAlmostEqual(realCenters[i], modelCenters, 1) - - def streamingKMeansDataGenerator(self, batches, numPoints, - k, d, r, seed, centers=None): - rng = random.RandomState(seed) - - # Generate centers. - centers = [rng.randn(d) for i in range(k)] - - return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) - for j in range(numPoints)] - for i in range(batches)] - - def test_trainOn_model(self): - """Test the model on toy data with four clusters.""" - stkm = StreamingKMeans() - initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] - stkm.setInitialCenters( - centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) - - # Create a toy dataset by setting a tiny offset for each point. - offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] - batches = [] - for offset in offsets: - batches.append([[offset[0] + center[0], offset[1] + center[1]] - for center in initCenters]) - - batches = [self.sc.parallelize(batch, 1) for batch in batches] - input_stream = self.ssc.queueStream(batches) - stkm.trainOn(input_stream) - self.ssc.start() - - # Give enough time to train the model. 
- def condition(): - finalModel = stkm.latestModel() - self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) - return True - self._eventually(condition, catch_assertions=True) - - def test_predictOn_model(self): - """Test that the model predicts correctly on toy data.""" - stkm = StreamingKMeans() - stkm._model = StreamingKMeansModel( - clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], - clusterWeights=[1.0, 1.0, 1.0, 1.0]) - - predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] - predict_stream = self.ssc.queueStream(predict_data) - predict_val = stkm.predictOn(predict_stream) - - result = [] - - def update(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - result.append(rdd_collect) - - predict_val.foreachRDD(update) - self.ssc.start() - - def condition(): - self.assertEqual(result, [[0], [1], [2], [3]]) - return True - - self._eventually(condition, catch_assertions=True) - - @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") - def test_trainOn_predictOn(self): - """Test that prediction happens on the updated model.""" - stkm = StreamingKMeans(decayFactor=0.0, k=2) - stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) - - # Since decay factor is set to zero, once the first batch - # is passed the clusterCenters are updated to [-0.5, 0.7] - # which causes 0.2 & 0.3 to be classified as 1, even though the - # classification based in the initial model would have been 0 - # proving that the model is updated. - batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [self.sc.parallelize(batch) for batch in batches] - input_stream = self.ssc.queueStream(batches) - predict_results = [] - - def collect(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - predict_results.append(rdd_collect) - - stkm.trainOn(input_stream) - predict_stream = stkm.predictOn(input_stream) - predict_stream.foreachRDD(collect) - - self.ssc.start() - - def condition(): - self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) - return True - - self._eventually(condition, catch_assertions=True) - - -class LinearDataGeneratorTests(MLlibTestCase): - def test_dim(self): - linear_data = LinearDataGenerator.generateLinearInput( - intercept=0.0, weights=[0.0, 0.0, 0.0], - xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], - nPoints=4, seed=0, eps=0.1) - self.assertEqual(len(linear_data), 4) - for point in linear_data: - self.assertEqual(len(point.features), 3) - - linear_data = LinearDataGenerator.generateLinearRDD( - sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, - nParts=2, intercept=0.0).collect() - self.assertEqual(len(linear_data), 6) - for point in linear_data: - self.assertEqual(len(point.features), 2) - - -class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): - - @staticmethod - def generateLogisticInput(offset, scale, nPoints, seed): - """ - Generate 1 / (1 + exp(-x * scale + offset)) - - where, - x is randomnly distributed and the threshold - and labels for each sample in x is obtained from a random uniform - distribution. - """ - rng = random.RandomState(seed) - x = rng.randn(nPoints) - sigmoid = 1. 
/ (1 + exp(-(dot(x, scale) + offset))) - y_p = rng.rand(nPoints) - cut_off = y_p <= sigmoid - y_p[cut_off] = 1.0 - y_p[~cut_off] = 0.0 - return [ - LabeledPoint(y_p[i], Vectors.dense([x[i]])) - for i in range(nPoints)] - - @unittest.skip("Super flaky test") - def test_parameter_accuracy(self): - """ - Test that the final value of weights is close to the desired value. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - - self.ssc.start() - - def condition(): - rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 - self.assertAlmostEqual(rel, 0.1, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_convergence(self): - """ - Test that weights converge to the required value on toy data. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - models = [] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - input_stream.foreachRDD( - lambda x: models.append(slr.latestModel().weights[0])) - - self.ssc.start() - - def condition(): - self.assertEqual(len(models), len(input_batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, 60.0, catch_assertions=True) - - t_models = array(models) - diff = t_models[1:] - t_models[:-1] - # Test that weights improve with a small tolerance - self.assertTrue(all(diff >= -0.1)) - self.assertTrue(array_sum(diff > 0) > 1) - - @staticmethod - def calculate_accuracy_error(true, predicted): - return sum(abs(array(true) - array(predicted))) / len(true) - - def test_predictions(self): - """Test predicted values on a toy model.""" - input_batches = [] - for i in range(20): - batch = self.sc.parallelize( - self.generateLogisticInput(0, 1.5, 100, 42 + i)) - input_batches.append(batch.map(lambda x: (x.label, x.features))) - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([1.5]) - predict_stream = slr.predictOnValues(input_stream) - true_predicted = [] - predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) - self.ssc.start() - - def condition(): - self.assertEqual(len(true_predicted), len(input_batches)) - return True - - self._eventually(condition, catch_assertions=True) - - # Test that the accuracy error is no more than 0.4 on each batch. - for batch in true_predicted: - true, predicted = zip(*batch) - self.assertTrue( - self.calculate_accuracy_error(true, predicted) < 0.4) - - @unittest.skip("Super flaky test") - def test_training_and_prediction(self): - """Test that the model improves on toy data with no. 
of batches""" - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.01, numIterations=25) - slr.setInitialWeights([-0.1]) - errors = [] - - def collect_errors(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(self.calculate_accuracy_error(true, predicted)) - - true_predicted = [] - input_stream = self.ssc.queueStream(input_batches) - predict_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - ps = slr.predictOnValues(predict_stream) - ps.foreachRDD(lambda x: collect_errors(x)) - - self.ssc.start() - - def condition(): - # Test that the improvement in error is > 0.3 - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 0.3) - if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): - - def assertArrayAlmostEqual(self, array1, array2, dec): - for i, j in array1, array2: - self.assertAlmostEqual(i, j, dec) - - @unittest.skip("Super flaky test") - def test_parameter_accuracy(self): - """Test that coefs are predicted accurately by fitting on toy data.""" - - # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients - # (10, 10) - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0, 0.0]) - xMean = [0.0, 0.0] - xVariance = [1.0 / 3.0, 1.0 / 3.0] - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - input_stream = self.ssc.queueStream(batches) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertArrayAlmostEqual( - slr.latestModel().weights.array, [10., 10.], 1) - self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_parameter_convergence(self): - """Test that the model parameters improve with streaming data.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - model_weights = [] - input_stream = self.ssc.queueStream(batches) - input_stream.foreachRDD( - lambda x: model_weights.append(slr.latestModel().weights[0])) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertEqual(len(model_weights), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - w = array(model_weights) - diff = w[1:] - w[:-1] - self.assertTrue(all(diff >= -0.1)) - - def test_prediction(self): - """Test prediction on a model with weights already set.""" - # Create a model with initial Weights equal to coefs - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([10.0, 10.0]) - - # Create ten batches with 100 sample points in each. 
- batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], - 100, 42 + i, 0.1) - batches.append( - self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) - - input_stream = self.ssc.queueStream(batches) - output_stream = slr.predictOnValues(input_stream) - samples = [] - output_stream.foreachRDD(lambda x: samples.append(x.collect())) - - self.ssc.start() - - def condition(): - self.assertEqual(len(samples), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - # Test that mean absolute error on each batch is less than 0.1 - for batch in samples: - true, predicted = zip(*batch) - self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) - - @unittest.skip("Super flaky test") - def test_train_prediction(self): - """Test that error on test data improves as model is trained.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in batches] - errors = [] - - def func(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(mean(abs(true) - abs(predicted))) - - input_stream = self.ssc.queueStream(batches) - output_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - output_stream = slr.predictOnValues(output_stream) - output_stream.foreachRDD(func) - self.ssc.start() - - def condition(): - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 2) - if len(errors) >= 3 and errors[1] - errors[-1] > 2: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class MLUtilsTests(MLlibTestCase): - def test_append_bias(self): - data = [2.0, 2.0, 2.0] - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_vector(self): - data = Vectors.dense([2.0, 2.0, 2.0]) - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_sp_vector(self): - data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) - expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) - # Returned value must be SparseVector - ret = MLUtils.appendBias(data) - self.assertEqual(ret, expected) - self.assertEqual(type(ret), SparseVector) - - def test_load_vectors(self): - import shutil - data = [ - [1.0, 2.0, 3.0], - [1.0, 2.0, 3.0] - ] - temp_dir = tempfile.mkdtemp() - load_vectors_path = os.path.join(temp_dir, "test_load_vectors") - try: - self.sc.parallelize(data).saveAsTextFile(load_vectors_path) - ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) - ret = ret_rdd.collect() - self.assertEqual(len(ret), 2) - self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) - self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) - except: - self.fail() - finally: - shutil.rmtree(load_vectors_path) - - -class ALSTests(MLlibTestCase): - - def test_als_ratings_serialize(self): - r = Rating(7, 1123, 3.14) - jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) - nr = 
ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) - self.assertEqual(r.user, nr.user) - self.assertEqual(r.product, nr.product) - self.assertAlmostEqual(r.rating, nr.rating, 2) - - def test_als_ratings_id_long_error(self): - r = Rating(1205640308657491975, 50233468418, 1.0) - # rating user id exceeds max int value, should fail when pickled - self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, - bytearray(ser.dumps(r))) - - -class HashingTFTest(MLlibTestCase): - - def test_binary_term_freqs(self): - hashingTF = HashingTF(100).setBinary(True) - doc = "a a b c c c".split(" ") - n = hashingTF.numFeatures - output = hashingTF.transform(doc).toArray() - expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, - hashingTF.indexOf("b"): 1.0, - hashingTF.indexOf("c"): 1.0}).toArray() - for i in range(0, n): - self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(output[i])) - - -class DimensionalityReductionTests(MLlibTestCase): - - denseData = [ - Vectors.dense([0.0, 1.0, 2.0]), - Vectors.dense([3.0, 4.0, 5.0]), - Vectors.dense([6.0, 7.0, 8.0]), - Vectors.dense([9.0, 0.0, 1.0]) - ] - sparseData = [ - Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), - Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), - Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), - Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) - ] - - def assertEqualUpToSign(self, vecA, vecB): - eq1 = vecA - vecB - eq2 = vecA + vecB - self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) - - def test_svd(self): - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - m = 4 - n = 3 - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - rm = mat.computeSVD(k, computeU=True) - self.assertEqual(rm.s.size, k) - self.assertEqual(rm.U.numRows(), m) - self.assertEqual(rm.U.numCols(), k) - self.assertEqual(rm.V.numRows, n) - self.assertEqual(rm.V.numCols, k) - - # Test that U returned is None if computeU is set to False. - self.assertEqual(mat.computeSVD(1).U, None) - - # Test that low rank matrices cannot have number of singular values - # greater than a limit. - rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) - self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) - - def test_pca(self): - expected_pcs = array([ - [0.0, 1.0, 0.0], - [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], - [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] - ]) - n = 3 - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - pcs = mat.computePrincipalComponents(k) - self.assertEqual(pcs.numRows, n) - self.assertEqual(pcs.numCols, k) - - # We can just test the updated principal component for equality. 
- self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) - - -class FPGrowthTest(MLlibTestCase): - - def test_fpgrowth(self): - data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] - rdd = self.sc.parallelize(data, 2) - model1 = FPGrowth.train(rdd, 0.6, 2) - # use default data partition number when numPartitions is not specified - model2 = FPGrowth.train(rdd, 0.6) - self.assertEqual(sorted(model1.freqItemsets().collect()), - sorted(model2.freqItemsets().collect())) - -if __name__ == "__main__": - from pyspark.mllib.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.mllib_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - sc.stop() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py deleted file mode 100644 index 206d529454f3d..0000000000000 --- a/python/pyspark/sql/tests.py +++ /dev/null @@ -1,7109 +0,0 @@ -# -*- encoding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for pyspark.sql; additional tests are implemented as doctests in -individual modules. -""" -import os -import sys -import subprocess -import pydoc -import shutil -import tempfile -import threading -import pickle -import functools -import time -import datetime -import array -import ctypes -import warnings -import py4j -from contextlib import contextmanager -import unishark - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark.util import _exception_message - -_pandas_requirement_message = None -try: - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() -except ImportError as e: - # If Pandas version requirement is not satisfied, skip related tests. - _pandas_requirement_message = _exception_message(e) - -_pyarrow_requirement_message = None -try: - from pyspark.sql.utils import require_minimum_pyarrow_version - require_minimum_pyarrow_version() -except ImportError as e: - # If Arrow version requirement is not satisfied, skip related tests. 
- _pyarrow_requirement_message = _exception_message(e) - -_test_not_compiled_message = None -try: - from pyspark.sql.utils import require_test_compiled - require_test_compiled() -except Exception as e: - _test_not_compiled_message = _exception_message(e) - -_have_pandas = _pandas_requirement_message is None -_have_pyarrow = _pyarrow_requirement_message is None -_test_compiled = _test_not_compiled_message is None - -from pyspark import SparkConf, SparkContext -from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row -from pyspark.sql.types import * -from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier -from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings -from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings -from pyspark.sql.types import _merge_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests -from pyspark.sql.functions import UserDefinedFunction, sha2, lit -from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException - - -class UTCOffsetTimezone(datetime.tzinfo): - """ - Specifies timezone in UTC offset - """ - - def __init__(self, offset=0): - self.ZERO = datetime.timedelta(hours=offset) - - def utcoffset(self, dt): - return self.ZERO - - def dst(self, dt): - return self.ZERO - - -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.sql.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, self.__class__) and \ - other.x == self.x and other.y == self.y - - -class PythonOnlyUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return '__main__' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return PythonOnlyPoint(datum[0], datum[1]) - - @staticmethod - def foo(): - pass - - @property - def props(self): - return {} - - -class PythonOnlyPoint(ExamplePoint): - """ - An example class to demonstrate UDT in only Python - """ - __UDT__ = PythonOnlyUDT() - - -class MyObject(object): - def __init__(self, key, value): - self.key = key - self.value = value - - -class SQLTestUtils(object): - """ - This util assumes the instance of this to have 'spark' attribute, having a spark session. - It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the - the implementation of this class has 'spark' attribute. - """ - - @contextmanager - def sql_conf(self, pairs): - """ - A convenient context manager to test some configuration specific logic. This sets - `value` to the configuration `key` and then restores it back when it exits. 
- """ - assert isinstance(pairs, dict), "pairs should be a dictionary." - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - keys = pairs.keys() - new_values = pairs.values() - old_values = [self.spark.conf.get(key, None) for key in keys] - for key, new_value in zip(keys, new_values): - self.spark.conf.set(key, new_value) - try: - yield - finally: - for key, old_value in zip(keys, old_values): - if old_value is None: - self.spark.conf.unset(key) - else: - self.spark.conf.set(key, old_value) - - @contextmanager - def database(self, *databases): - """ - A convenient context manager to test with some specific databases. This drops the given - databases if exist and sets current database to "default" when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for db in databases: - self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db) - self.spark.catalog.setCurrentDatabase("default") - - @contextmanager - def table(self, *tables): - """ - A convenient context manager to test with some specific tables. This drops the given tables - if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for t in tables: - self.spark.sql("DROP TABLE IF EXISTS %s" % t) - - @contextmanager - def tempView(self, *views): - """ - A convenient context manager to test with some specific views. This drops the given views - if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for v in views: - self.spark.catalog.dropTempView(v) - - @contextmanager - def function(self, *functions): - """ - A convenient context manager to test with some specific functions. This drops the given - functions if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." 
- - try: - yield - finally: - for f in functions: - self.spark.sql("DROP FUNCTION IF EXISTS %s" % f) - - -class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): - @classmethod - def setUpClass(cls): - super(ReusedSQLTestCase, cls).setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - super(ReusedSQLTestCase, cls).tearDownClass() - cls.spark.stop() - - def assertPandasEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + - "\n\nResult:\n%s\n%s" % (result, result.dtypes)) - self.assertTrue(expected.equals(result), msg=msg) - - -class DataTypeTests(unittest.TestCase): - # regression test for SPARK-6055 - def test_data_type_eq(self): - lt = LongType() - lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEqual(lt, lt2) - - # regression test for SPARK-7978 - def test_decimal_type(self): - t1 = DecimalType() - t2 = DecimalType(10, 2) - self.assertTrue(t2 is not t1) - self.assertNotEqual(t1, t2) - t3 = DecimalType(8) - self.assertNotEqual(t2, t3) - - # regression test for SPARK-10392 - def test_datetype_equal_zero(self): - dt = DateType() - self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) - - # regression test for SPARK-17035 - def test_timestamp_microsecond(self): - tst = TimestampType() - self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999) - - def test_empty_row(self): - row = Row() - self.assertEqual(len(row), 0) - - def test_struct_field_type_name(self): - struct_field = StructField("a", IntegerType()) - self.assertRaises(TypeError, struct_field.typeName) - - def test_invalid_create_row(self): - row_class = Row("c1", "c2") - self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) - - -class SparkSessionBuilderTests(unittest.TestCase): - - def test_create_spark_context_first_then_spark_session(self): - sc = None - session = None - try: - conf = SparkConf().set("key1", "value1") - sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf) - session = SparkSession.builder.config("key2", "value2").getOrCreate() - - self.assertEqual(session.conf.get("key1"), "value1") - self.assertEqual(session.conf.get("key2"), "value2") - self.assertEqual(session.sparkContext, sc) - - self.assertFalse(sc.getConf().contains("key2")) - self.assertEqual(sc.getConf().get("key1"), "value1") - finally: - if session is not None: - session.stop() - if sc is not None: - sc.stop() - - def test_another_spark_session(self): - session1 = None - session2 = None - try: - session1 = SparkSession.builder.config("key1", "value1").getOrCreate() - session2 = SparkSession.builder.config("key2", "value2").getOrCreate() - - self.assertEqual(session1.conf.get("key1"), "value1") - self.assertEqual(session2.conf.get("key1"), "value1") - self.assertEqual(session1.conf.get("key2"), "value2") - self.assertEqual(session2.conf.get("key2"), "value2") - self.assertEqual(session1.sparkContext, session2.sparkContext) - - self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1") - self.assertFalse(session1.sparkContext.getConf().contains("key2")) - finally: - if session1 is not None: - session1.stop() - if session2 is not None: - session2.stop() - - -class SQLTests(ReusedSQLTestCase): - - @classmethod - def setUpClass(cls): - ReusedSQLTestCase.setUpClass() - cls.spark.catalog._reset() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = 
cls.spark.createDataFrame(cls.testData) - - @classmethod - def tearDownClass(cls): - ReusedSQLTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def test_sqlcontext_reuses_sparksession(self): - sqlContext1 = SQLContext(self.sc) - sqlContext2 = SQLContext(self.sc) - self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) - - def test_row_should_be_read_only(self): - row = Row(a=1, b=2) - self.assertEqual(1, row.a) - - def foo(): - row.a = 3 - self.assertRaises(Exception, foo) - - row2 = self.spark.range(10).first() - self.assertEqual(0, row2.id) - - def foo2(): - row2.id = 2 - self.assertRaises(Exception, foo2) - - def test_range(self): - self.assertEqual(self.spark.range(1, 1).count(), 0) - self.assertEqual(self.spark.range(1, 0, -1).count(), 1) - self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2) - self.assertEqual(self.spark.range(-2).count(), 0) - self.assertEqual(self.spark.range(3).count(), 3) - - def test_duplicated_column_names(self): - df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) - row = df.select('*').first() - self.assertEqual(1, row[0]) - self.assertEqual(2, row[1]) - self.assertEqual("Row(c=1, c=2)", str(row)) - # Cannot access columns - self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) - self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) - self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) - - def test_column_name_encoding(self): - """Ensure that created columns has `str` type consistently.""" - columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns - self.assertEqual(columns, ['name', 'age']) - self.assertTrue(isinstance(columns[0], str)) - self.assertTrue(isinstance(columns[1], str)) - - def test_explode(self): - from pyspark.sql.functions import explode, explode_outer, posexplode_outer - d = [ - Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}), - Row(a=1, intlist=[], mapfield={}), - Row(a=1, intlist=None, mapfield=None), - ] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - result = data.select(explode(data.intlist).alias("a")).select("a").collect() - self.assertEqual(result[0][0], 1) - self.assertEqual(result[1][0], 2) - self.assertEqual(result[2][0], 3) - - result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() - self.assertEqual(result[0][0], "a") - self.assertEqual(result[0][1], "b") - - result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()] - self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]) - - result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()] - self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)]) - - result = [x[0] for x in data.select(explode_outer("intlist")).collect()] - self.assertEqual(result, [1, 2, 3, None, None]) - - result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()] - self.assertEqual(result, [('a', 'b'), (None, None), (None, None)]) - - def test_and_in_expression(self): - self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) - self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) - self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) - self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") - self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) - 
self.assertRaises(ValueError, lambda: not self.df.key == 1) - - def test_udf_with_callable(self): - d = [Row(number=i, squared=i**2) for i in range(10)] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - class PlusFour: - def __call__(self, col): - if col is not None: - return col + 4 - - call = PlusFour() - pudf = UserDefinedFunction(call, LongType()) - res = data.select(pudf(data['number']).alias('plus_four')) - self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) - - def test_udf_with_partial_function(self): - d = [Row(number=i, squared=i**2) for i in range(10)] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - def some_func(col, param): - if col is not None: - return col + param - - pfunc = functools.partial(some_func, param=4) - pudf = UserDefinedFunction(pfunc, LongType()) - res = data.select(pudf(data['number']).alias('plus_four')) - self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) - - def test_udf(self): - self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. - sqlContext = self.spark._wrapped - sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) - [row] = sqlContext.sql("SELECT oneArg('test')").collect() - self.assertEqual(row[0], 4) - - def test_udf2(self): - with self.tempView("test"): - self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ - .createOrReplaceTempView("test") - [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf3(self): - two_args = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) - self.assertEqual(two_args.deterministic, True) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], u'5') - - def test_udf_registration_return_type_none(self): - two_args = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) - self.assertEqual(two_args.deterministic, True) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf_registration_return_type_not_none(self): - with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, "Invalid returnType"): - self.spark.catalog.registerFunction( - "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) - - def test_nondeterministic_udf(self): - # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - self.assertEqual(udf_random_col.deterministic, False) - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - - def test_nondeterministic_udf2(self): - import random - from pyspark.sql.functions import udf - random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() - self.assertEqual(random_udf.deterministic, False) - 
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) - self.assertEqual(random_udf1.deterministic, False) - [row] = self.spark.sql("SELECT randInt()").collect() - self.assertEqual(row[0], 6) - [row] = self.spark.range(1).select(random_udf1()).collect() - self.assertEqual(row[0], 6) - [row] = self.spark.range(1).select(random_udf()).collect() - self.assertEqual(row[0], 6) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) - pydoc.render_doc(random_udf) - pydoc.render_doc(random_udf1) - pydoc.render_doc(udf(lambda x: x).asNondeterministic) - - def test_nondeterministic_udf3(self): - # regression test for SPARK-23233 - from pyspark.sql.functions import udf - f = udf(lambda x: x) - # Here we cache the JVM UDF instance. - self.spark.range(1).select(f("id")) - # This should reset the cache to set the deterministic status correctly. - f = f.asNondeterministic() - # Check the deterministic status of udf. - df = self.spark.range(1).select(f("id")) - deterministic = df._jdf.logicalPlan().projectList().head().deterministic() - self.assertFalse(deterministic) - - def test_nondeterministic_udf_in_aggregate(self): - from pyspark.sql.functions import udf, sum - import random - udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() - df = self.spark.range(10) - - with QuietTest(self.sc): - with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): - df.groupby('id').agg(sum(udf_random_col())).collect() - with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): - df.agg(sum(udf_random_col())).collect() - - def test_chained_udf(self): - self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) - [row] = self.spark.sql("SELECT double(1)").collect() - self.assertEqual(row[0], 2) - [row] = self.spark.sql("SELECT double(double(1))").collect() - self.assertEqual(row[0], 4) - [row] = self.spark.sql("SELECT double(double(1) + 1)").collect() - self.assertEqual(row[0], 6) - - def test_single_udf_with_repeated_argument(self): - # regression test for SPARK-20685 - self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) - row = self.spark.sql("SELECT add(1, 1)").first() - self.assertEqual(tuple(row), (2, )) - - def test_multiple_udfs(self): - self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType()) - [row] = self.spark.sql("SELECT double(1), double(2)").collect() - self.assertEqual(tuple(row), (2, 4)) - [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect() - self.assertEqual(tuple(row), (4, 12)) - self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) - [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() - self.assertEqual(tuple(row), (6, 5)) - - def test_udf_in_filter_on_top_of_outer_join(self): - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(a=1)]) - df = left.join(right, on='a', how='left_outer') - df = df.withColumn('b', udf(lambda x: 'x')(df.a)) - self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) - - def test_udf_in_filter_on_top_of_join(self): - # regression test for SPARK-18589 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(b=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.crossJoin(right).filter(f("a", "b")) - 
self.assertEqual(df.collect(), [Row(a=1, b=1)]) - - def test_udf_in_join_condition(self): - # regression test for SPARK-25314 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(b=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, f("a", "b")) - with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): - df.collect() - with self.sql_conf({"spark.sql.crossJoin.enabled": True}): - self.assertEqual(df.collect(), [Row(a=1, b=1)]) - - def test_udf_in_left_semi_join_condition(self): - # regression test for SPARK-25314 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, f("a", "b"), "leftsemi") - with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): - df.collect() - with self.sql_conf({"spark.sql.crossJoin.enabled": True}): - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) - - def test_udf_and_common_filter_in_join_condition(self): - # regression test for SPARK-25314 - # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, [f("a", "b"), left.a1 == right.b1]) - # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) - - def test_udf_and_common_filter_in_left_semi_join_condition(self): - # regression test for SPARK-25314 - # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi") - # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) - - def test_udf_not_supported_in_join_condition(self): - # regression test for SPARK-25314 - # test python udf is not supported in join type besides left_semi and inner join. - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - - def runWithJoinType(join_type, type_string): - with self.assertRaisesRegexp( - AnalysisException, - 'Using PythonUDF.*%s is not supported.' 
% type_string): - left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() - runWithJoinType("full", "FullOuter") - runWithJoinType("left", "LeftOuter") - runWithJoinType("right", "RightOuter") - runWithJoinType("leftanti", "LeftAnti") - - def test_udf_without_arguments(self): - self.spark.catalog.registerFunction("foo", lambda: "bar") - [row] = self.spark.sql("SELECT foo()").collect() - self.assertEqual(row[0], "bar") - - def test_udf_with_array_type(self): - with self.tempView("test"): - d = [Row(l=list(range(3)), d={"key": list(range(5))})] - rdd = self.sc.parallelize(d) - self.spark.createDataFrame(rdd).createOrReplaceTempView("test") - self.spark.catalog.registerFunction( - "copylist", lambda l: list(l), ArrayType(IntegerType())) - self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(list(range(3)), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.spark.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.spark.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_udf_with_filter_function(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col - from pyspark.sql.types import BooleanType - - my_filter = udf(lambda a: a < 2, BooleanType()) - sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2")) - self.assertEqual(sel.collect(), [Row(key=1, value='1')]) - - def test_udf_with_aggregate_function(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col, sum - from pyspark.sql.types import BooleanType - - my_filter = udf(lambda a: a == 1, BooleanType()) - sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) - self.assertEqual(sel.collect(), [Row(key=1)]) - - my_copy = udf(lambda x: x, IntegerType()) - my_add = udf(lambda a, b: int(a + b), IntegerType()) - my_strlen = udf(lambda x: len(x), IntegerType()) - sel = df.groupBy(my_copy(col("key")).alias("k"))\ - .agg(sum(my_strlen(col("value"))).alias("s"))\ - .select(my_add(col("k"), col("s")).alias("t")) - self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) - - def test_udf_in_generate(self): - from pyspark.sql.functions import udf, explode - df = self.spark.range(5) - f = udf(lambda x: list(range(x)), ArrayType(LongType())) - row = df.select(explode(f(*df))).groupBy().sum().first() - self.assertEqual(row[0], 10) - - df = self.spark.range(3) - res = df.select("id", explode(f(df.id))).collect() - self.assertEqual(res[0][0], 1) - self.assertEqual(res[0][1], 0) - self.assertEqual(res[1][0], 2) - self.assertEqual(res[1][1], 0) - self.assertEqual(res[2][0], 2) - self.assertEqual(res[2][1], 1) - - range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType())) - res = df.select("id", explode(range_udf(df.id))).collect() - self.assertEqual(res[0][0], 0) - self.assertEqual(res[0][1], -1) - self.assertEqual(res[1][0], 0) - self.assertEqual(res[1][1], 0) - self.assertEqual(res[2][0], 1) - self.assertEqual(res[2][1], 0) - self.assertEqual(res[3][0], 1) - self.assertEqual(res[3][1], 1) - - def 
test_udf_with_order_by_and_limit(self): - from pyspark.sql.functions import udf - my_copy = udf(lambda x: x, IntegerType()) - df = self.spark.range(10).orderBy("id") - res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) - res.explain(True) - self.assertEqual(res.collect(), [Row(id=0, copy=0)]) - - def test_udf_registration_returns_udf(self): - df = self.spark.range(10) - add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) - - self.assertListEqual( - df.selectExpr("add_three(id) AS plus_three").collect(), - df.select(add_three("id").alias("plus_three")).collect() - ) - - # This is to check if a 'SQLContext.udf' can call its alias. - sqlContext = self.spark._wrapped - add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) - - self.assertListEqual( - df.selectExpr("add_four(id) AS plus_four").collect(), - df.select(add_four("id").alias("plus_four")).collect() - ) - - def test_non_existed_udf(self): - spark = self.spark - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", - lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) - - # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. - sqlContext = spark._wrapped - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", - lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) - - def test_non_existed_udaf(self): - spark = self.spark - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", - lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - - def test_linesep_text(self): - df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") - expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), - Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), - Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), - Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] - self.assertEqual(df.collect(), expected) - - tpath = tempfile.mkdtemp() - shutil.rmtree(tpath) - try: - df.write.text(tpath, lineSep="!") - expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), - Row(value=u'Tom!30!"My name is Tom"'), - Row(value=u'Hyukjin!25!"I am Hyukjin'), - Row(value=u''), Row(value=u'I love Spark!"'), - Row(value=u'!')] - readback = self.spark.read.text(tpath) - self.assertEqual(readback.collect(), expected) - finally: - shutil.rmtree(tpath) - - def test_multiline_json(self): - people1 = self.spark.read.json("python/test_support/sql/people.json") - people_array = self.spark.read.json("python/test_support/sql/people_array.json", - multiLine=True) - self.assertEqual(people1.collect(), people_array.collect()) - - def test_encoding_json(self): - people_array = self.spark.read\ - .json("python/test_support/sql/people_array_utf16le.json", - multiLine=True, encoding="UTF-16LE") - expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] - self.assertEqual(people_array.collect(), expected) - - def test_linesep_json(self): - df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") - expected = [Row(_corrupt_record=None, name=u'Michael'), - Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), - Row(_corrupt_record=u' "age":19}\n', name=None)] - self.assertEqual(df.collect(), expected) - - tpath = tempfile.mkdtemp() - shutil.rmtree(tpath) - try: - df = self.spark.read.json("python/test_support/sql/people.json") - df.write.json(tpath, lineSep="!!") - readback = 
self.spark.read.json(tpath, lineSep="!!") - self.assertEqual(readback.collect(), df.collect()) - finally: - shutil.rmtree(tpath) - - def test_multiline_csv(self): - ages_newlines = self.spark.read.csv( - "python/test_support/sql/ages_newlines.csv", multiLine=True) - expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), - Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), - Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] - self.assertEqual(ages_newlines.collect(), expected) - - def test_ignorewhitespace_csv(self): - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv( - tmpPath, - ignoreLeadingWhiteSpace=False, - ignoreTrailingWhiteSpace=False) - - expected = [Row(value=u' a,b , c ')] - readback = self.spark.read.text(tmpPath) - self.assertEqual(readback.collect(), expected) - shutil.rmtree(tmpPath) - - def test_read_multiple_orc_file(self): - df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", - "python/test_support/sql/orc_partitioned/b=1/c=1"]) - self.assertEqual(2, df.count()) - - def test_udf_with_input_file_name(self): - from pyspark.sql.functions import udf, input_file_name - sourceFile = udf(lambda path: path, StringType()) - filePath = "python/test_support/sql/people1.json" - row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() - self.assertTrue(row[0].find("people1.json") != -1) - - def test_udf_with_input_file_name_for_hadooprdd(self): - from pyspark.sql.functions import udf, input_file_name - - def filename(path): - return path - - sameText = udf(filename, StringType()) - - rdd = self.sc.textFile('python/test_support/sql/people.json') - df = self.spark.read.json(rdd).select(input_file_name().alias('file')) - row = df.select(sameText(df['file'])).first() - self.assertTrue(row[0].find("people.json") != -1) - - rdd2 = self.sc.newAPIHadoopFile( - 'python/test_support/sql/people.json', - 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat', - 'org.apache.hadoop.io.LongWritable', - 'org.apache.hadoop.io.Text') - - df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file')) - row2 = df2.select(sameText(df2['file'])).first() - self.assertTrue(row2[0].find("people.json") != -1) - - def test_udf_defers_judf_initialization(self): - # This is separate of UDFInitializationTests - # to avoid context initialization - # when udf is called - - from pyspark.sql.functions import UserDefinedFunction - - f = UserDefinedFunction(lambda x: x, StringType()) - - self.assertIsNone( - f._judf_placeholder, - "judf should not be initialized before the first call." - ) - - self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.") - - self.assertIsNotNone( - f._judf_placeholder, - "judf should be initialized after UDF has been called." 
- ) - - def test_udf_with_string_return_type(self): - from pyspark.sql.functions import UserDefinedFunction - - add_one = UserDefinedFunction(lambda x: x + 1, "integer") - make_pair = UserDefinedFunction(lambda x: (-x, x), "struct") - make_array = UserDefinedFunction( - lambda x: [float(x) for x in range(x, x + 3)], "array") - - expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0]) - actual = (self.spark.range(1, 2).toDF("x") - .select(add_one("x"), make_pair("x"), make_array("x")) - .first()) - - self.assertTupleEqual(expected, actual) - - def test_udf_shouldnt_accept_noncallable_object(self): - from pyspark.sql.functions import UserDefinedFunction - - non_callable = None - self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) - - def test_udf_with_decorator(self): - from pyspark.sql.functions import lit, udf - from pyspark.sql.types import IntegerType, DoubleType - - @udf(IntegerType()) - def add_one(x): - if x is not None: - return x + 1 - - @udf(returnType=DoubleType()) - def add_two(x): - if x is not None: - return float(x + 2) - - @udf - def to_upper(x): - if x is not None: - return x.upper() - - @udf() - def to_lower(x): - if x is not None: - return x.lower() - - @udf - def substr(x, start, end): - if x is not None: - return x[start:end] - - @udf("long") - def trunc(x): - return int(x) - - @udf(returnType="double") - def as_double(x): - return float(x) - - df = ( - self.spark - .createDataFrame( - [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float")) - .select( - add_one("one"), add_two("one"), - to_upper("Foo"), to_lower("Foo"), - substr("foobar", lit(0), lit(3)), - trunc("float"), as_double("one"))) - - self.assertListEqual( - [tpe for _, tpe in df.dtypes], - ["int", "double", "string", "string", "string", "bigint", "double"] - ) - - self.assertListEqual( - list(df.first()), - [2, 3.0, "FOO", "foo", "foo", 3, 1.0] - ) - - def test_udf_wrapper(self): - from pyspark.sql.functions import udf - from pyspark.sql.types import IntegerType - - def f(x): - """Identity""" - return x - - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - class F(object): - """Identity""" - def __call__(self, x): - return x - - f = F() - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - f = functools.partial(f, x=1) - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - def test_validate_column_types(self): - from pyspark.sql.functions import udf, to_json - from pyspark.sql.column import _to_java_column - - self.assertTrue("Column" in _to_java_column("a").getClass().toString()) - self.assertTrue("Column" in _to_java_column(u"a").getClass().toString()) - self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString()) - - self.assertRaisesRegexp( - TypeError, - "Invalid argument, not a string or column", - lambda: _to_java_column(1)) - - class A(): - pass - - self.assertRaises(TypeError, lambda: _to_java_column(A())) - self.assertRaises(TypeError, lambda: _to_java_column([])) - - self.assertRaisesRegexp( - TypeError, - "Invalid argument, not a string or column", - lambda: udf(lambda x: x)(None)) - self.assertRaises(TypeError, lambda: to_json(1)) - - def 
test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.spark.read.json(rdd) - df.count() - df.collect() - df.schema - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist(True) - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - with self.tempView("temp"): - df.createOrReplaceTempView("temp") - df = self.spark.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.spark.createDataFrame(rdd, df.schema) - self.assertEqual(10, df3.count()) - - def test_infer_schema_to_local(self): - input = [{"a": 1}, {"b": "coffee"}] - rdd = self.sc.parallelize(input) - df = self.spark.createDataFrame(input) - df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) - self.assertEqual(df.schema, df2.schema) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.spark.createDataFrame(rdd, df.schema) - self.assertEqual(10, df3.count()) - - def test_apply_schema_to_dict_and_rows(self): - schema = StructType().add("b", StringType()).add("a", IntegerType()) - input = [{"a": 1}, {"b": "coffee"}] - rdd = self.sc.parallelize(input) - for verify in [False, True]: - df = self.spark.createDataFrame(input, schema, verifySchema=verify) - df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) - self.assertEqual(df.schema, df2.schema) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) - self.assertEqual(10, df3.count()) - input = [Row(a=x, b=str(x)) for x in range(10)] - df4 = self.spark.createDataFrame(input, schema, verifySchema=verify) - self.assertEqual(10, df4.count()) - - def test_create_dataframe_schema_mismatch(self): - input = [Row(a=1)] - rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) - schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) - df = self.spark.createDataFrame(rdd, schema) - self.assertRaises(Exception, lambda: df.show()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.spark.createDataFrame(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.rdd.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.rdd.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.rdd.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}, s=None), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.spark.createDataFrame(rdd) - self.assertEqual([], df.rdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) - - with self.tempView("test"): - df.createOrReplaceTempView("test") - result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) - 
self.assertEqual(df.schema, df2.schema) - self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) - - with self.tempView("test2"): - df2.createOrReplaceTempView("test2") - result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_infer_schema_specification(self): - from decimal import Decimal - - class A(object): - def __init__(self): - self.a = 1 - - data = [ - True, - 1, - "a", - u"a", - datetime.date(1970, 1, 1), - datetime.datetime(1970, 1, 1, 0, 0), - 1.0, - array.array("d", [1]), - [1], - (1, ), - {"a": 1}, - bytearray(1), - Decimal(1), - Row(a=1), - Row("a")(1), - A(), - ] - - df = self.spark.createDataFrame([data]) - actual = list(map(lambda x: x.dataType.simpleString(), df.schema)) - expected = [ - 'boolean', - 'bigint', - 'string', - 'string', - 'date', - 'timestamp', - 'double', - 'array', - 'array', - 'struct<_1:bigint>', - 'map', - 'binary', - 'decimal(38,18)', - 'struct', - 'struct', - 'struct', - ] - self.assertEqual(actual, expected) - - actual = list(df.first()) - expected = [ - True, - 1, - 'a', - u"a", - datetime.date(1970, 1, 1), - datetime.datetime(1970, 1, 1, 0, 0), - 1.0, - [1.0], - [1], - Row(_1=1), - {"a": 1}, - bytearray(b'\x00'), - Decimal('1.000000000000000000'), - Row(a=1), - Row(a=1), - Row(a=1), - ] - self.assertEqual(actual, expected) - - def test_infer_schema_not_enough_names(self): - df = self.spark.createDataFrame([["a", "b"]], ["col1"]) - self.assertEqual(df.columns, ['col1', '_2']) - - def test_infer_schema_fails(self): - with self.assertRaisesRegexp(TypeError, 'field a'): - self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), - schema=["a", "b"], samplingRatio=0.99) - - def test_infer_nested_schema(self): - NestedRow = Row("f1", "f2") - nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), - NestedRow([2, 3], {"row2": 2.0})]) - df = self.spark.createDataFrame(nestedRdd1) - self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) - - nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), - NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.spark.createDataFrame(nestedRdd2) - self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) - - from collections import namedtuple - CustomRow = namedtuple('CustomRow', 'field1 field2') - rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), - CustomRow(field1=2, field2="row2"), - CustomRow(field1=3, field2="row3")]) - df = self.spark.createDataFrame(rdd) - self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) - - def test_create_dataframe_from_dict_respects_schema(self): - df = self.spark.createDataFrame([{'a': 1}], ["b"]) - self.assertEqual(df.columns, ['b']) - - def test_create_dataframe_from_objects(self): - data = [MyObject(1, "1"), MyObject(2, "2")] - df = self.spark.createDataFrame(data) - self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) - self.assertEqual(df.first(), Row(key=1, value="1")) - - def test_select_null_literal(self): - df = self.spark.sql("select null as col") - self.assertEqual(Row(col=None), df.first()) - - def test_apply_schema(self): - from datetime import date, datetime - rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, - date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), - {"a": 1}, (2,), [1, 2, 3], None)]) - schema = StructType([ - StructField("byte1", ByteType(), False), - StructField("byte2", ByteType(), False), 
- StructField("short1", ShortType(), False), - StructField("short2", ShortType(), False), - StructField("int1", IntegerType(), False), - StructField("float1", FloatType(), False), - StructField("date1", DateType(), False), - StructField("time1", TimestampType(), False), - StructField("map1", MapType(StringType(), IntegerType(), False), False), - StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), - StructField("list1", ArrayType(ByteType(), False), False), - StructField("null1", DoubleType(), True)]) - df = self.spark.createDataFrame(rdd, schema) - results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, - x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) - r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), - datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - self.assertEqual(r, results.first()) - - with self.tempView("table2"): - df.createOrReplaceTempView("table2") - r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + - "float1 + 1.5 as float1 FROM table2").first() - - self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - df = self.sc.parallelize(d).toDF() - k, v = list(df.head().m.items())[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - df = self.sc.parallelize([row]).toDF() - - with self.tempView("test"): - df.createOrReplaceTempView("test") - row = self.spark.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_udt(self): - from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier - from pyspark.sql.tests import ExamplePointUDT, ExamplePoint - - def check_datatype(datatype): - pickled = pickle.loads(pickle.dumps(datatype)) - assert datatype == pickled - scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json()) - python_datatype = _parse_datatype_json_string(scala_datatype.json()) - assert datatype == python_datatype - - check_datatype(ExamplePointUDT()) - structtype_with_udt = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - check_datatype(structtype_with_udt) - p = ExamplePoint(1.0, 2.0) - self.assertEqual(_infer_type(p), ExamplePointUDT()) - _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) - self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) - - check_datatype(PythonOnlyUDT()) - structtype_with_udt = StructType([StructField("label", DoubleType(), False), - StructField("point", PythonOnlyUDT(), False)]) - check_datatype(structtype_with_udt) - p = PythonOnlyPoint(1.0, 2.0) - self.assertEqual(_infer_type(p), PythonOnlyUDT()) - _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) - self.assertRaises( - ValueError, - lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) - - def test_simple_udt_in_df(self): - schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) - df = self.spark.createDataFrame( - [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], - schema=schema) - df.collect() - - def test_nested_udt_in_df(self): - schema = StructType().add("key", 
LongType()).add("val", ArrayType(PythonOnlyUDT())) - df = self.spark.createDataFrame( - [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], - schema=schema) - df.collect() - - schema = StructType().add("key", LongType()).add("val", - MapType(LongType(), PythonOnlyUDT())) - df = self.spark.createDataFrame( - [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], - schema=schema) - df.collect() - - def test_complex_nested_udt_in_df(self): - from pyspark.sql.functions import udf - - schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) - df = self.spark.createDataFrame( - [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], - schema=schema) - df.collect() - - gd = df.groupby("key").agg({"val": "collect_list"}) - gd.collect() - udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) - gd.select(udf(*gd)).collect() - - def test_udt_with_none(self): - df = self.spark.range(0, 10, 1, 1) - - def myudf(x): - if x > 0: - return PythonOnlyPoint(float(x), float(x)) - - self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT()) - rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] - self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) - - def test_nonparam_udf_with_aggregate(self): - import pyspark.sql.functions as f - - df = self.spark.createDataFrame([(1, 2), (1, 2)]) - f_udf = f.udf(lambda: "const_str") - rows = df.distinct().withColumn("a", f_udf()).collect() - self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) - - def test_infer_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - schema = df.schema - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - - with self.tempView("labeled_point"): - df.createOrReplaceTempView("labeled_point") - point = self.spark.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - schema = df.schema - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), PythonOnlyUDT) - - with self.tempView("labeled_point"): - df.createOrReplaceTempView("labeled_point") - point = self.spark.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - point = df.head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = (1.0, PythonOnlyPoint(1.0, 2.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", PythonOnlyUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - point = df.head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_udf_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = UserDefinedFunction(lambda p: p.y, DoubleType()) - self.assertEqual(2.0, 
df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) - self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = UserDefinedFunction(lambda p: p.y, DoubleType()) - self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) - self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) - - def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.spark.createDataFrame([row]) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.write.parquet(output_dir) - df1 = self.spark.read.parquet(output_dir) - point = df1.head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df0 = self.spark.createDataFrame([row]) - df0.write.parquet(output_dir, mode='overwrite') - df1 = self.spark.read.parquet(output_dir) - point = df1.head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_union_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row1 = (1.0, ExamplePoint(1.0, 2.0)) - row2 = (2.0, ExamplePoint(3.0, 4.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df1 = self.spark.createDataFrame([row1], schema) - df2 = self.spark.createDataFrame([row2], schema) - - result = df1.union(df2).orderBy("label").collect() - self.assertEqual( - result, - [ - Row(label=1.0, point=ExamplePoint(1.0, 2.0)), - Row(label=2.0, point=ExamplePoint(3.0, 4.0)) - ] - ) - - def test_cast_to_string_with_udt(self): - from pyspark.sql.tests import ExamplePointUDT, ExamplePoint - from pyspark.sql.functions import col - row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) - schema = StructType([StructField("point", ExamplePointUDT(), False), - StructField("pypoint", PythonOnlyUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - - result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() - self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) - - def test_column_operators(self): - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbool = (ci & ci), (ci | ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbool)) - css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ - cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - self.assertRaisesRegexp(ValueError, - "Cannot apply 'in' operator against a column", - lambda: 1 in cs) - - def test_column_getitem(self): - from pyspark.sql.functions import col - - self.assertIsInstance(col("foo")[1:3], Column) - self.assertIsInstance(col("foo")[0], Column) - 
self.assertIsInstance(col("foo")["bar"], Column) - self.assertRaises(ValueError, lambda: col("foo")[0:10:2]) - - def test_column_select(self): - df = self.df - self.assertEqual(self.testData, df.select("*").collect()) - self.assertEqual(self.testData, df.select(df.key, df.value).collect()) - self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - - def test_freqItems(self): - vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] - df = self.sc.parallelize(vals).toDF() - items = df.stat.freqItems(("a", "b"), 0.4).collect()[0] - self.assertTrue(1 in items[0]) - self.assertTrue(-2.0 in items[1]) - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import functions - self.assertEqual((0, u'99'), - tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) - self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) - - def test_first_last_ignorenulls(self): - from pyspark.sql import functions - df = self.spark.range(0, 100) - df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) - df3 = df2.select(functions.first(df2.id, False).alias('a'), - functions.first(df2.id, True).alias('b'), - functions.last(df2.id, False).alias('c'), - functions.last(df2.id, True).alias('d')) - self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) - - def test_approxQuantile(self): - df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF() - for f in ["a", u"a"]: - aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aq, list)) - self.assertEqual(len(aq), 3) - self.assertTrue(all(isinstance(q, float) for q in aq)) - aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aqs, list)) - self.assertEqual(len(aqs), 2) - self.assertTrue(isinstance(aqs[0], list)) - self.assertEqual(len(aqs[0]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqs[0])) - self.assertTrue(isinstance(aqs[1], list)) - self.assertEqual(len(aqs[1]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqs[1])) - aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aqt, list)) - self.assertEqual(len(aqt), 2) - self.assertTrue(isinstance(aqt[0], list)) - self.assertEqual(len(aqt[0]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqt[0])) - self.assertTrue(isinstance(aqt[1], list)) - self.assertEqual(len(aqt[1]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqt[1])) - self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1)) - self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1)) - self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1)) - - def test_corr(self): - import math - df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() - corr = df.stat.corr(u"a", "b") - self.assertTrue(abs(corr - 0.95734012) < 1e-6) - - def test_sampleby(self): - df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF() - sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0) - self.assertTrue(sampled.count() == 3) - - def test_cov(self): - df = self.sc.parallelize([Row(a=i, 
b=2 * i) for i in range(10)]).toDF() - cov = df.stat.cov(u"a", "b") - self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) - - def test_crosstab(self): - df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() - ct = df.stat.crosstab(u"a", "b").collect() - ct = sorted(ct, key=lambda x: x[0]) - for i, row in enumerate(ct): - self.assertEqual(row[0], str(i)) - self.assertTrue(row[1], 1) - self.assertTrue(row[2], 1) - - def test_math_functions(self): - df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() - from pyspark.sql import functions - import math - - def get_values(l): - return [j[0] for j in l] - - def assert_close(a, b): - c = get_values(b) - diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)] - return sum(diff) == len(a) - assert_close([math.cos(i) for i in range(10)], - df.select(functions.cos(df.a)).collect()) - assert_close([math.cos(i) for i in range(10)], - df.select(functions.cos("a")).collect()) - assert_close([math.sin(i) for i in range(10)], - df.select(functions.sin(df.a)).collect()) - assert_close([math.sin(i) for i in range(10)], - df.select(functions.sin(df['a'])).collect()) - assert_close([math.pow(i, 2 * i) for i in range(10)], - df.select(functions.pow(df.a, df.b)).collect()) - assert_close([math.pow(i, 2) for i in range(10)], - df.select(functions.pow(df.a, 2)).collect()) - assert_close([math.pow(i, 2) for i in range(10)], - df.select(functions.pow(df.a, 2.0)).collect()) - assert_close([math.hypot(i, 2 * i) for i in range(10)], - df.select(functions.hypot(df.a, df.b)).collect()) - - def test_rand_functions(self): - df = self.df - from pyspark.sql import functions - rnd = df.select('key', functions.rand()).collect() - for row in rnd: - assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1] - rndn = df.select('key', functions.randn(5)).collect() - for row in rndn: - assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] - - # If the specified seed is 0, we should use it. 
- # https://issues.apache.org/jira/browse/SPARK-9691 - rnd1 = df.select('key', functions.rand(0)).collect() - rnd2 = df.select('key', functions.rand(0)).collect() - self.assertEqual(sorted(rnd1), sorted(rnd2)) - - rndn1 = df.select('key', functions.randn(0)).collect() - rndn2 = df.select('key', functions.randn(0)).collect() - self.assertEqual(sorted(rndn1), sorted(rndn2)) - - def test_string_functions(self): - from pyspark.sql.functions import col, lit - df = self.spark.createDataFrame([['nick']], schema=['name']) - self.assertRaisesRegexp( - TypeError, - "must be the same type", - lambda: df.select(col('name').substr(0, lit(1)))) - if sys.version_info.major == 2: - self.assertRaises( - TypeError, - lambda: df.select(col('name').substr(long(0), long(1)))) - - def test_array_contains_function(self): - from pyspark.sql.functions import array_contains - - df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data']) - actual = df.select(array_contains(df.data, "1").alias('b')).collect() - self.assertEqual([Row(b=True), Row(b=False)], actual) - - def test_between_function(self): - df = self.sc.parallelize([ - Row(a=1, b=2, c=3), - Row(a=2, b=1, c=3), - Row(a=4, b=1, c=4)]).toDF() - self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], - df.filter(df.a.between(df.b, df.c)).collect()) - - def test_struct_type(self): - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - struct2 = StructType([StructField("f1", StringType(), True), - StructField("f2", StringType(), True, None)]) - self.assertEqual(struct1.fieldNames(), struct2.names) - self.assertEqual(struct1, struct2) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - struct2 = StructType([StructField("f1", StringType(), True)]) - self.assertNotEqual(struct1.fieldNames(), struct2.names) - self.assertNotEqual(struct1, struct2) - - struct1 = (StructType().add(StructField("f1", StringType(), True)) - .add(StructField("f2", StringType(), True, None))) - struct2 = StructType([StructField("f1", StringType(), True), - StructField("f2", StringType(), True, None)]) - self.assertEqual(struct1.fieldNames(), struct2.names) - self.assertEqual(struct1, struct2) - - struct1 = (StructType().add(StructField("f1", StringType(), True)) - .add(StructField("f2", StringType(), True, None))) - struct2 = StructType([StructField("f1", StringType(), True)]) - self.assertNotEqual(struct1.fieldNames(), struct2.names) - self.assertNotEqual(struct1, struct2) - - # Catch exception raised during improper construction - self.assertRaises(ValueError, lambda: StructType().add("name")) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - for field in struct1: - self.assertIsInstance(field, StructField) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - self.assertEqual(len(struct1), 2) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - self.assertIs(struct1["f1"], struct1.fields[0]) - self.assertIs(struct1[0], struct1.fields[0]) - self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) - self.assertRaises(KeyError, lambda: struct1["f9"]) - self.assertRaises(IndexError, lambda: struct1[9]) - self.assertRaises(TypeError, lambda: struct1[9.9]) - - def test_struct_type_to_internal(self): - # Verify when not needSerializeAnyField - struct = StructType().add("b", StringType()).add("a", StringType()) - string_a = "value_a" - string_b = "value_b" - row = 
Row(a=string_a, b=string_b) - tupleResult = struct.toInternal(row) - # Reversed because of struct - self.assertEqual(tupleResult, (string_b, string_a)) - - # Verify when needSerializeAnyField - struct1 = StructType().add("b", TimestampType()).add("a", TimestampType()) - timestamp_a = datetime.datetime(2018, 1, 1, 1, 1, 1) - timestamp_b = datetime.datetime(2019, 1, 1, 1, 1, 1) - row = Row(a=timestamp_a, b=timestamp_b) - tupleResult = struct1.toInternal(row) - # Reversed because of struct - d = 1000000 - ts_b = tupleResult[0] - new_timestamp_b = datetime.datetime.fromtimestamp(ts_b // d).replace(microsecond=ts_b % d) - ts_a = tupleResult[1] - new_timestamp_a = datetime.datetime.fromtimestamp(ts_a // d).replace(microsecond=ts_a % d) - self.assertEqual(timestamp_a, new_timestamp_a) - self.assertEqual(timestamp_b, new_timestamp_b) - - def test_parse_datatype_string(self): - from pyspark.sql.types import _all_atomic_types, _parse_datatype_string - for k, t in _all_atomic_types.items(): - if t != NullType: - self.assertEqual(t(), _parse_datatype_string(k)) - self.assertEqual(IntegerType(), _parse_datatype_string("int")) - self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)")) - self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )")) - self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)")) - self.assertEqual( - ArrayType(IntegerType()), - _parse_datatype_string("array")) - self.assertEqual( - MapType(IntegerType(), DoubleType()), - _parse_datatype_string("map< int, double >")) - self.assertEqual( - StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), - _parse_datatype_string("struct")) - self.assertEqual( - StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), - _parse_datatype_string("a:int, c:double")) - self.assertEqual( - StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), - _parse_datatype_string("a INT, c DOUBLE")) - - def test_metadata_null(self): - schema = StructType([StructField("f1", StringType(), True, None), - StructField("f2", StringType(), True, {'a': None})]) - rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) - self.spark.createDataFrame(rdd, schema) - - def test_save_and_load(self): - df = self.df - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.json(tmpPath) - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - schema = StructType([StructField("value", StringType(), True)]) - actual = self.spark.read.json(tmpPath, schema) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - - df.write.json(tmpPath, "overwrite") - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - df.write.save(format="json", mode="overwrite", path=tmpPath, - noUse="this options will not be used in save.") - actual = self.spark.read.load(format="json", path=tmpPath, - noUse="this options will not be used in load.") - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.spark.read.load(path=tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) - - csvpath = os.path.join(tempfile.mkdtemp(), 'data') 
- df.write.option('quote', None).format('csv').save(csvpath) - - shutil.rmtree(tmpPath) - - def test_save_and_load_builder(self): - df = self.df - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.json(tmpPath) - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - schema = StructType([StructField("value", StringType(), True)]) - actual = self.spark.read.json(tmpPath, schema) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - - df.write.mode("overwrite").json(tmpPath) - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ - .option("noUse", "this option will not be used in save.")\ - .format("json").save(path=tmpPath) - actual =\ - self.spark.read.format("json")\ - .load(path=tmpPath, noUse="this options will not be used in load.") - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.spark.read.load(path=tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) - - shutil.rmtree(tmpPath) - - def test_stream_trigger(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - - # Should take at least one arg - try: - df.writeStream.trigger() - except ValueError: - pass - - # Should not take multiple args - try: - df.writeStream.trigger(once=True, processingTime='5 seconds') - except ValueError: - pass - - # Should not take multiple args - try: - df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') - except ValueError: - pass - - # Should take only keyword args - try: - df.writeStream.trigger('5 seconds') - self.fail("Should have thrown an exception") - except TypeError: - pass - - def test_stream_read_options(self): - schema = StructType([StructField("data", StringType(), False)]) - df = self.spark.readStream\ - .format('text')\ - .option('path', 'python/test_support/sql/streaming')\ - .schema(schema)\ - .load() - self.assertTrue(df.isStreaming) - self.assertEqual(df.schema.simpleString(), "struct") - - def test_stream_read_options_overwrite(self): - bad_schema = StructType([StructField("test", IntegerType(), False)]) - schema = StructType([StructField("data", StringType(), False)]) - df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \ - .schema(bad_schema)\ - .load(path='python/test_support/sql/streaming', schema=schema, format='text') - self.assertTrue(df.isStreaming) - self.assertEqual(df.schema.simpleString(), "struct") - - def test_stream_save_options(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \ - .withColumn('id', lit(1)) - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \ - .format('parquet').partitionBy('id').outputMode('append').option('path', out).start() - try: - self.assertEqual(q.name, 'this_query') - self.assertTrue(q.isActive) - 
q.processAllAvailable() - output_files = [] - for _, _, files in os.walk(out): - output_files.extend([f for f in files if not f.startswith('.')]) - self.assertTrue(len(output_files) > 0) - self.assertTrue(len(os.listdir(chk)) > 0) - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_save_options_overwrite(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - fake1 = os.path.join(tmpPath, 'fake1') - fake2 = os.path.join(tmpPath, 'fake2') - q = df.writeStream.option('checkpointLocation', fake1)\ - .format('memory').option('path', fake2) \ - .queryName('fake_query').outputMode('append') \ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - - try: - self.assertEqual(q.name, 'this_query') - self.assertTrue(q.isActive) - q.processAllAvailable() - output_files = [] - for _, _, files in os.walk(out): - output_files.extend([f for f in files if not f.startswith('.')]) - self.assertTrue(len(output_files) > 0) - self.assertTrue(len(os.listdir(chk)) > 0) - self.assertFalse(os.path.isdir(fake1)) # should not have been created - self.assertFalse(os.path.isdir(fake2)) # should not have been created - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_status_and_progress(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - - def func(x): - time.sleep(1) - return x - - from pyspark.sql.functions import col, udf - sleep_udf = udf(func) - - # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there - # were no updates. - q = df.select(sleep_udf(col("value")).alias('value')).writeStream \ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - try: - # "lastProgress" will return None in most cases. However, as it may be flaky when - # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress" - # may throw error with a high chance and make this test flaky, so we should still be - # able to detect broken codes. 
- q.lastProgress - - q.processAllAvailable() - lastProgress = q.lastProgress - recentProgress = q.recentProgress - status = q.status - self.assertEqual(lastProgress['name'], q.name) - self.assertEqual(lastProgress['id'], q.id) - self.assertTrue(any(p == lastProgress for p in recentProgress)) - self.assertTrue( - "message" in status and - "isDataAvailable" in status and - "isTriggerActive" in status) - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_await_termination(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - q = df.writeStream\ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - try: - self.assertTrue(q.isActive) - try: - q.awaitTermination("hello") - self.fail("Expected a value exception") - except ValueError: - pass - now = time.time() - # test should take at least 2 seconds - res = q.awaitTermination(2.6) - duration = time.time() - now - self.assertTrue(duration >= 2) - self.assertFalse(res) - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_exception(self): - sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - sq = sdf.writeStream.format('memory').queryName('query_explain').start() - try: - sq.processAllAvailable() - self.assertEqual(sq.exception(), None) - finally: - sq.stop() - - from pyspark.sql.functions import col, udf - from pyspark.sql.utils import StreamingQueryException - bad_udf = udf(lambda x: 1 / 0) - sq = sdf.select(bad_udf(col("value")))\ - .writeStream\ - .format('memory')\ - .queryName('this_query')\ - .start() - try: - # Process some data to fail the query - sq.processAllAvailable() - self.fail("bad udf should fail the query") - except StreamingQueryException as e: - # This is expected - self.assertTrue("ZeroDivisionError" in e.desc) - finally: - sq.stop() - self.assertTrue(type(sq.exception()) is StreamingQueryException) - self.assertTrue("ZeroDivisionError" in sq.exception().desc) - - def test_query_manager_await_termination(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - q = df.writeStream\ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - try: - self.assertTrue(q.isActive) - try: - self.spark._wrapped.streams.awaitAnyTermination("hello") - self.fail("Expected a value exception") - except ValueError: - pass - now = time.time() - # test should take at least 2 seconds - res = self.spark._wrapped.streams.awaitAnyTermination(2.6) - duration = time.time() - now - self.assertTrue(duration >= 2) - self.assertFalse(res) - finally: - q.stop() - shutil.rmtree(tmpPath) - - class ForeachWriterTester: - - def __init__(self, spark): - self.spark = spark - - def write_open_event(self, partitionId, epochId): - self._write_event( - self.open_events_dir, - {'partition': partitionId, 'epoch': epochId}) - - def write_process_event(self, row): - self._write_event(self.process_events_dir, {'value': 'text'}) - - def write_close_event(self, error): - self._write_event(self.close_events_dir, {'error': str(error)}) - - 
-        def write_input_file(self):
-            self._write_event(self.input_dir, "text")
-
-        def open_events(self):
-            return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
-
-        def process_events(self):
-            return self._read_events(self.process_events_dir, 'value STRING')
-
-        def close_events(self):
-            return self._read_events(self.close_events_dir, 'error STRING')
-
-        def run_streaming_query_on_writer(self, writer, num_files):
-            self._reset()
-            try:
-                sdf = self.spark.readStream.format('text').load(self.input_dir)
-                sq = sdf.writeStream.foreach(writer).start()
-                for i in range(num_files):
-                    self.write_input_file()
-                    sq.processAllAvailable()
-            finally:
-                self.stop_all()
-
-        def assert_invalid_writer(self, writer, msg=None):
-            self._reset()
-            try:
-                sdf = self.spark.readStream.format('text').load(self.input_dir)
-                sq = sdf.writeStream.foreach(writer).start()
-                self.write_input_file()
-                sq.processAllAvailable()
-                self.fail("invalid writer %s did not fail the query" % str(writer))  # not expected
-            except Exception as e:
-                if msg:
-                    assert msg in str(e), "%s not in %s" % (msg, str(e))
-
-            finally:
-                self.stop_all()
-
-        def stop_all(self):
-            for q in self.spark._wrapped.streams.active:
-                q.stop()
-
-        def _reset(self):
-            self.input_dir = tempfile.mkdtemp()
-            self.open_events_dir = tempfile.mkdtemp()
-            self.process_events_dir = tempfile.mkdtemp()
-            self.close_events_dir = tempfile.mkdtemp()
-
-        def _read_events(self, dir, json):
-            rows = self.spark.read.schema(json).json(dir).collect()
-            dicts = [row.asDict() for row in rows]
-            return dicts
-
-        def _write_event(self, dir, event):
-            import uuid
-            with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
-                f.write("%s\n" % str(event))
-
-        def __getstate__(self):
-            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
-
-        def __setstate__(self, state):
-            self.open_events_dir, self.process_events_dir, self.close_events_dir = state
-
-    # Those foreach tests are failed in Python 3.6 and macOS High Sierra by defined rules
-    # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html
-    # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES.
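
The foreach tests that follow exercise the contract of `DataStreamWriter.foreach()`: it accepts either a plain function invoked once per row, or an object with a required `process(row)` method and optional `open(partition_id, epoch_id)` and `close(error)` hooks. The snippet below is a minimal, self-contained sketch of that contract, not part of the patch; the local session settings, the class name `SketchWriter`, and the input path `/tmp/foreach-in` are assumptions chosen only for illustration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("foreach_sketch").getOrCreate()

class SketchWriter(object):
    def open(self, partition_id, epoch_id):
        # Optional hook: return True to have process() called for this partition/epoch.
        return True

    def process(self, row):
        # Required: called once per row of the streaming DataFrame.
        print(row.value)

    def close(self, error):
        # Optional hook: error is None when the partition was processed successfully.
        pass

sdf = spark.readStream.format("text").load("/tmp/foreach-in")  # assumed input directory
query = sdf.writeStream.foreach(SketchWriter()).start()
query.processAllAvailable()
query.stop()

A plain function taking a row can be passed in place of the object, which is what test_streaming_foreach_with_simple_function relies on; the foreachBatch tests further below use the related `foreachBatch(func)` sink, whose function receives `(batch_df, batch_id)` rather than individual rows.
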
- def test_streaming_foreach_with_simple_function(self): - tester = self.ForeachWriterTester(self.spark) - - def foreach_func(row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(foreach_func, 2) - self.assertEqual(len(tester.process_events()), 2) - - def test_streaming_foreach_with_basic_open_process_close(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partitionId, epochId): - tester.write_open_event(partitionId, epochId) - return True - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - open_events = tester.open_events() - self.assertEqual(len(open_events), 2) - self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) - - self.assertEqual(len(tester.process_events()), 2) - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) - - def test_streaming_foreach_with_open_returning_false(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return False - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - self.assertEqual(len(tester.open_events()), 2) - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) - - def test_streaming_foreach_without_open_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 2) - - def test_streaming_foreach_without_close_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return True - - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 2) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_without_open_and_close_methods(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_with_process_throwing_error(self): - from pyspark.sql.utils import StreamingQueryException - - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - raise Exception("test error") - - def close(self, error): - tester.write_close_event(error) - - try: - 
tester.run_streaming_query_on_writer(ForeachWriter(), 1) - self.fail("bad writer did not fail the query") # this is not expected - except StreamingQueryException as e: - # TODO: Verify whether original error message is inside the exception - pass - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - close_events = tester.close_events() - self.assertEqual(len(close_events), 1) - # TODO: Verify whether original error message is inside the exception - - def test_streaming_foreach_with_invalid_writers(self): - - tester = self.ForeachWriterTester(self.spark) - - def func_with_iterator_input(iter): - for x in iter: - print(x) - - tester.assert_invalid_writer(func_with_iterator_input) - - class WriterWithoutProcess: - def open(self, partition): - pass - - tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") - - class WriterWithNonCallableProcess(): - process = True - - tester.assert_invalid_writer(WriterWithNonCallableProcess(), - "'process' in provided object is not callable") - - class WriterWithNoParamProcess(): - def process(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamProcess()) - - # Abstract class for tests below - class WithProcess(): - def process(self, row): - pass - - class WriterWithNonCallableOpen(WithProcess): - open = True - - tester.assert_invalid_writer(WriterWithNonCallableOpen(), - "'open' in provided object is not callable") - - class WriterWithNoParamOpen(WithProcess): - def open(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamOpen()) - - class WriterWithNonCallableClose(WithProcess): - close = True - - tester.assert_invalid_writer(WriterWithNonCallableClose(), - "'close' in provided object is not callable") - - def test_streaming_foreachBatch(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - collected[batch_id] = batch_df.collect() - - try: - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_propagates_python_errors(self): - from pyspark.sql.utils import StreamingQueryException - - q = None - - def collectBatch(df, id): - raise Exception("this should fail the query") - - try: - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.fail("Expected a failure") - except StreamingQueryException as e: - self.assertTrue("this should fail" in str(e)) - finally: - if q: - q.stop() - - def test_help_command(self): - # Regression test for SPARK-5464 - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.spark.read.json(rdd) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(df) - pydoc.render_doc(df.foo) - pydoc.render_doc(df.take(1)) - - def test_access_column(self): - df = self.df - self.assertTrue(isinstance(df.key, Column)) - self.assertTrue(isinstance(df['key'], Column)) - self.assertTrue(isinstance(df[0], Column)) - self.assertRaises(IndexError, lambda: df[2]) - self.assertRaises(AnalysisException, lambda: df["bad_key"]) - self.assertRaises(TypeError, lambda: df[{}]) - - def test_column_name_with_non_ascii(self): - if sys.version >= '3': - columnName = "数量" - self.assertTrue(isinstance(columnName, str)) - else: - columnName = 
unicode("数量", "utf-8") - self.assertTrue(isinstance(columnName, unicode)) - schema = StructType([StructField(columnName, LongType(), True)]) - df = self.spark.createDataFrame([(1,)], schema) - self.assertEqual(schema, df.schema) - self.assertEqual("DataFrame[数量: bigint]", str(df)) - self.assertEqual([("数量", 'bigint')], df.dtypes) - self.assertEqual(1, df.select("数量").first()[0]) - self.assertEqual(1, df.select(df["数量"]).first()[0]) - - def test_access_nested_types(self): - df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() - self.assertEqual(1, df.select(df.l[0]).first()[0]) - self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) - self.assertEqual(1, df.select(df.r.a).first()[0]) - self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) - self.assertEqual("v", df.select(df.d["k"]).first()[0]) - self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) - - def test_field_accessor(self): - df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() - self.assertEqual(1, df.select(df.l[0]).first()[0]) - self.assertEqual(1, df.select(df.r["a"]).first()[0]) - self.assertEqual(1, df.select(df["r.a"]).first()[0]) - self.assertEqual("b", df.select(df.r["b"]).first()[0]) - self.assertEqual("b", df.select(df["r.b"]).first()[0]) - self.assertEqual("v", df.select(df.d["k"]).first()[0]) - - def test_infer_long_type(self): - longrow = [Row(f1='a', f2=100000000000000)] - df = self.sc.parallelize(longrow).toDF() - self.assertEqual(df.schema.fields[1].dataType, LongType()) - - # this saving as Parquet caused issues as well. - output_dir = os.path.join(self.tempdir.name, "infer_long_type") - df.write.parquet(output_dir) - df1 = self.spark.read.parquet(output_dir) - self.assertEqual('a', df1.first().f1) - self.assertEqual(100000000000000, df1.first().f2) - - self.assertEqual(_infer_type(1), LongType()) - self.assertEqual(_infer_type(2**10), LongType()) - self.assertEqual(_infer_type(2**20), LongType()) - self.assertEqual(_infer_type(2**31 - 1), LongType()) - self.assertEqual(_infer_type(2**31), LongType()) - self.assertEqual(_infer_type(2**61), LongType()) - self.assertEqual(_infer_type(2**71), LongType()) - - def test_merge_type(self): - self.assertEqual(_merge_type(LongType(), NullType()), LongType()) - self.assertEqual(_merge_type(NullType(), LongType()), LongType()) - - self.assertEqual(_merge_type(LongType(), LongType()), LongType()) - - self.assertEqual(_merge_type( - ArrayType(LongType()), - ArrayType(LongType()) - ), ArrayType(LongType())) - with self.assertRaisesRegexp(TypeError, 'element in array'): - _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) - - self.assertEqual(_merge_type( - MapType(StringType(), LongType()), - MapType(StringType(), LongType()) - ), MapType(StringType(), LongType())) - with self.assertRaisesRegexp(TypeError, 'key of map'): - _merge_type( - MapType(StringType(), LongType()), - MapType(DoubleType(), LongType())) - with self.assertRaisesRegexp(TypeError, 'value of map'): - _merge_type( - MapType(StringType(), LongType()), - MapType(StringType(), DoubleType())) - - self.assertEqual(_merge_type( - StructType([StructField("f1", LongType()), StructField("f2", StringType())]), - StructType([StructField("f1", LongType()), StructField("f2", StringType())]) - ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'field f1'): - _merge_type( - StructType([StructField("f1", LongType()), StructField("f2", StringType())]), - 
StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) - - self.assertEqual(_merge_type( - StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), - StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) - ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) - with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): - _merge_type( - StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), - StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) - - self.assertEqual(_merge_type( - StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), - StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) - ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'element in array field f1'): - _merge_type( - StructType([ - StructField("f1", ArrayType(LongType())), - StructField("f2", StringType())]), - StructType([ - StructField("f1", ArrayType(DoubleType())), - StructField("f2", StringType())])) - - self.assertEqual(_merge_type( - StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())]), - StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())]) - ), StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'value of map field f1'): - _merge_type( - StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())]), - StructType([ - StructField("f1", MapType(StringType(), DoubleType())), - StructField("f2", StringType())])) - - self.assertEqual(_merge_type( - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) - ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) - with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): - _merge_type( - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), - StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) - ) - - def test_filter_with_datetime(self): - time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) - date = time.date() - row = Row(date=date, time=time) - df = self.spark.createDataFrame([row]) - self.assertEqual(1, df.filter(df.date == date).count()) - self.assertEqual(1, df.filter(df.time == time).count()) - self.assertEqual(0, df.filter(df.date > date).count()) - self.assertEqual(0, df.filter(df.time > time).count()) - - def test_filter_with_datetime_timezone(self): - dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) - dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) - row = Row(date=dt1) - df = self.spark.createDataFrame([row]) - self.assertEqual(0, df.filter(df.date == dt2).count()) - self.assertEqual(1, df.filter(df.date > dt2).count()) - self.assertEqual(0, df.filter(df.date < dt2).count()) - - def test_time_with_timezone(self): - day = datetime.date.today() - now = datetime.datetime.now() - ts = time.mktime(now.timetuple()) - # class in __main__ is not serializable - from pyspark.sql.tests import UTCOffsetTimezone - utc = UTCOffsetTimezone() - utcnow = 
datetime.datetime.utcfromtimestamp(ts) # without microseconds - # add microseconds to utcnow (keeping year,month,day,hour,minute,second) - utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) - df = self.spark.createDataFrame([(day, now, utcnow)]) - day1, now1, utcnow1 = df.first() - self.assertEqual(day1, day) - self.assertEqual(now, now1) - self.assertEqual(now, utcnow1) - - # regression test for SPARK-19561 - def test_datetime_at_epoch(self): - epoch = datetime.datetime.fromtimestamp(0) - df = self.spark.createDataFrame([Row(date=epoch)]) - first = df.select('date', lit(epoch).alias('lit_date')).first() - self.assertEqual(first['date'], epoch) - self.assertEqual(first['lit_date'], epoch) - - def test_dayofweek(self): - from pyspark.sql.functions import dayofweek - dt = datetime.datetime(2017, 11, 6) - df = self.spark.createDataFrame([Row(date=dt)]) - row = df.select(dayofweek(df.date)).first() - self.assertEqual(row[0], 2) - - def test_decimal(self): - from decimal import Decimal - schema = StructType([StructField("decimal", DecimalType(10, 5))]) - df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema) - row = df.select(df.decimal + 1).first() - self.assertEqual(row[0], Decimal("4.14159")) - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.parquet(tmpPath) - df2 = self.spark.read.parquet(tmpPath) - row = df2.first() - self.assertEqual(row[0], Decimal("3.14159")) - - def test_dropna(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) - - # shouldn't drop a non-null row - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, 80.1)], schema).dropna().count(), - 1) - - # dropping rows with a single null value - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna().count(), - 0) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), - 0) - - # if how = 'all', only drop rows if all values are null - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(None, None, None)], schema).dropna(how='all').count(), - 0) - - # how and subset - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), - 0) - - # threshold - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, None)], schema).dropna(thresh=2).count(), - 0) - - # threshold and subset - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), - 0) - - # thresh should take precedence over how - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, None)], schema).dropna( - how='any', thresh=2, subset=['name', 'age']).count(), - 1) - - def test_fillna(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True), - 
StructField("spy", BooleanType(), True)]) - - # fillna shouldn't change non-null values - row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first() - self.assertEqual(row.age, 10) - - # fillna with int - row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first() - self.assertEqual(row.age, 50) - self.assertEqual(row.height, 50.0) - - # fillna with double - row = self.spark.createDataFrame( - [(u'Alice', None, None, None)], schema).fillna(50.1).first() - self.assertEqual(row.age, 50) - self.assertEqual(row.height, 50.1) - - # fillna with bool - row = self.spark.createDataFrame( - [(u'Alice', None, None, None)], schema).fillna(True).first() - self.assertEqual(row.age, None) - self.assertEqual(row.spy, True) - - # fillna with string - row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first() - self.assertEqual(row.name, u"hello") - self.assertEqual(row.age, None) - - # fillna with subset specified for numeric cols - row = self.spark.createDataFrame( - [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first() - self.assertEqual(row.name, None) - self.assertEqual(row.age, 50) - self.assertEqual(row.height, None) - self.assertEqual(row.spy, None) - - # fillna with subset specified for string cols - row = self.spark.createDataFrame( - [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() - self.assertEqual(row.name, "haha") - self.assertEqual(row.age, None) - self.assertEqual(row.height, None) - self.assertEqual(row.spy, None) - - # fillna with subset specified for bool cols - row = self.spark.createDataFrame( - [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first() - self.assertEqual(row.name, None) - self.assertEqual(row.age, None) - self.assertEqual(row.height, None) - self.assertEqual(row.spy, True) - - # fillna with dictionary for boolean types - row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() - self.assertEqual(row.a, True) - - def test_bitwise_operations(self): - from pyspark.sql import functions - row = Row(a=170, b=75) - df = self.spark.createDataFrame([row]) - result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict() - self.assertEqual(170 & 75, result['(a & b)']) - result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict() - self.assertEqual(170 | 75, result['(a | b)']) - result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict() - self.assertEqual(170 ^ 75, result['(a ^ b)']) - result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() - self.assertEqual(~75, result['~b']) - - def test_expr(self): - from pyspark.sql import functions - row = Row(a="length string", b=75) - df = self.spark.createDataFrame([row]) - result = df.select(functions.expr("length(a)")).collect()[0].asDict() - self.assertEqual(13, result["length(a)"]) - - def test_repartitionByRange_dataframe(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) - - df1 = self.spark.createDataFrame( - [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) - df2 = self.spark.createDataFrame( - [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) - - # test repartitionByRange(numPartitions, *cols) - df3 = df1.repartitionByRange(2, "name", "age") - self.assertEqual(df3.rdd.getNumPartitions(), 2) - self.assertEqual(df3.rdd.first(), df2.rdd.first()) - 
self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) - - # test repartitionByRange(numPartitions, *cols) - df4 = df1.repartitionByRange(3, "name", "age") - self.assertEqual(df4.rdd.getNumPartitions(), 3) - self.assertEqual(df4.rdd.first(), df2.rdd.first()) - self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) - - # test repartitionByRange(*cols) - df5 = df1.repartitionByRange("name", "age") - self.assertEqual(df5.rdd.first(), df2.rdd.first()) - self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) - - def test_replace(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) - - # replace with int - row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first() - self.assertEqual(row.age, 20) - self.assertEqual(row.height, 20.0) - - # replace with double - row = self.spark.createDataFrame( - [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first() - self.assertEqual(row.age, 82) - self.assertEqual(row.height, 82.1) - - # replace with string - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first() - self.assertEqual(row.name, u"Ann") - self.assertEqual(row.age, 10) - - # replace with subset specified by a string of a column name w/ actual change - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first() - self.assertEqual(row.age, 20) - - # replace with subset specified by a string of a column name w/o actual change - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first() - self.assertEqual(row.age, 10) - - # replace with subset specified with one column replaced, another column not in subset - # stays unchanged. 
- row = self.spark.createDataFrame( - [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first() - self.assertEqual(row.name, u'Alice') - self.assertEqual(row.age, 20) - self.assertEqual(row.height, 10.0) - - # replace with subset specified but no column will be replaced - row = self.spark.createDataFrame( - [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first() - self.assertEqual(row.name, u'Alice') - self.assertEqual(row.age, 10) - self.assertEqual(row.height, None) - - # replace with lists - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() - self.assertTupleEqual(row, (u'Ann', 10, 80.1)) - - # replace with dict - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() - self.assertTupleEqual(row, (u'Alice', 11, 80.1)) - - # test backward compatibility with dummy value - dummy_value = 1 - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() - self.assertTupleEqual(row, (u'Bob', 10, 80.1)) - - # test dict with mixed numerics - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() - self.assertTupleEqual(row, (u'Alice', -10, 90.5)) - - # replace with tuples - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() - self.assertTupleEqual(row, (u'Bob', 10, 80.1)) - - # replace multiple columns - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() - self.assertTupleEqual(row, (u'Alice', 20, 90.0)) - - # test for mixed numerics - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() - self.assertTupleEqual(row, (u'Alice', 20, 90.5)) - - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() - self.assertTupleEqual(row, (u'Alice', 20, 90.5)) - - # replace with boolean - row = (self - .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) - .selectExpr("name = 'Bob'", 'age <= 15') - .replace(False, True).first()) - self.assertTupleEqual(row, (True, True)) - - # replace string with None and then drop None rows - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() - self.assertEqual(row.count(), 0) - - # replace with number and None - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first() - self.assertTupleEqual(row, (u'Alice', 20, None)) - - # should fail if subset is not list, tuple or None - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() - - # should fail if to_replace and value have different length - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() - - # should fail if when received unexpected type - with self.assertRaises(ValueError): - from datetime import datetime - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() - - # should fail if provided mixed type replacements - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() - - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 
10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() - - with self.assertRaisesRegexp( - TypeError, - 'value argument is required when to_replace is not a dictionary.'): - self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() - - def test_capture_analysis_exception(self): - self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) - self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - - def test_capture_parse_exception(self): - self.assertRaises(ParseException, lambda: self.spark.sql("abc")) - - def test_capture_illegalargument_exception(self): - self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", - lambda: self.spark.sql("SET mapred.reduce.tasks=-1")) - df = self.spark.createDataFrame([(1, 2)], ["a", "b"]) - self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", - lambda: df.select(sha2(df.a, 1024)).collect()) - try: - df.select(sha2(df.a, 1024)).collect() - except IllegalArgumentException as e: - self.assertRegexpMatches(e.desc, "1024 is not in the permitted values") - self.assertRegexpMatches(e.stackTrace, - "org.apache.spark.sql.functions") - - def test_with_column_with_existing_name(self): - keys = self.df.withColumn("key", self.df.key).select("key").collect() - self.assertEqual([r.key for r in keys], list(range(100))) - - # regression test for SPARK-10417 - def test_column_iterator(self): - - def foo(): - for x in self.df.key: - break - - self.assertRaises(TypeError, foo) - - # add test for SPARK-10577 (test broadcast join hint) - def test_functions_broadcast(self): - from pyspark.sql.functions import broadcast - - df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) - df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) - - # equijoin - should be converted into broadcast join - plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() - self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) - - # no join key -- should not be a broadcast join - plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan() - self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) - - # planner should not crash without a join - broadcast(df1)._jdf.queryExecution().executedPlan() - - def test_generic_hints(self): - from pyspark.sql import DataFrame - - df1 = self.spark.range(10e10).toDF("id") - df2 = self.spark.range(10e10).toDF("id") - - self.assertIsInstance(df1.hint("broadcast"), DataFrame) - self.assertIsInstance(df1.hint("broadcast", []), DataFrame) - - # Dummy rules - self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame) - self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame) - - plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() - self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) - - def test_sample(self): - self.assertRaisesRegexp( - TypeError, - "should be a bool, float and number", - lambda: self.spark.range(1).sample()) - - self.assertRaises( - TypeError, - lambda: self.spark.range(1).sample("a")) - - self.assertRaises( - TypeError, - lambda: self.spark.range(1).sample(seed="abc")) - - self.assertRaises( - IllegalArgumentException, - lambda: self.spark.range(1).sample(-1.0)) - - def test_toDF_with_schema_string(self): - data = [Row(key=i, value=str(i)) for i in range(100)] - rdd = self.sc.parallelize(data, 5) - - df = rdd.toDF("key: int, value: string") - 
self.assertEqual(df.schema.simpleString(), "struct") - self.assertEqual(df.collect(), data) - - # different but compatible field types can be used. - df = rdd.toDF("key: string, value: string") - self.assertEqual(df.schema.simpleString(), "struct") - self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)]) - - # field names can differ. - df = rdd.toDF(" a: int, b: string ") - self.assertEqual(df.schema.simpleString(), "struct") - self.assertEqual(df.collect(), data) - - # number of fields must match. - self.assertRaisesRegexp(Exception, "Length of object", - lambda: rdd.toDF("key: int").collect()) - - # field types mismatch will cause exception at runtime. - self.assertRaisesRegexp(Exception, "FloatType can not accept", - lambda: rdd.toDF("key: float, value: string").collect()) - - # flat schema values will be wrapped into row. - df = rdd.map(lambda row: row.key).toDF("int") - self.assertEqual(df.schema.simpleString(), "struct") - self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) - - # users can use DataType directly instead of data type string. - df = rdd.map(lambda row: row.key).toDF(IntegerType()) - self.assertEqual(df.schema.simpleString(), "struct") - self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) - - def test_join_without_on(self): - df1 = self.spark.range(1).toDF("a") - df2 = self.spark.range(1).toDF("b") - - with self.sql_conf({"spark.sql.crossJoin.enabled": False}): - self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) - - with self.sql_conf({"spark.sql.crossJoin.enabled": True}): - actual = df1.join(df2, how="inner").collect() - expected = [Row(a=0, b=0)] - self.assertEqual(actual, expected) - - # Regression test for invalid join methods when on is None, Spark-14761 - def test_invalid_join_method(self): - df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) - df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) - self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) - - # Cartesian products require cross join syntax - def test_require_cross(self): - from pyspark.sql.functions import broadcast - - df1 = self.spark.createDataFrame([(1, "1")], ("key", "value")) - df2 = self.spark.createDataFrame([(1, "1")], ("key", "value")) - - # joins without conditions require cross join syntax - self.assertRaises(AnalysisException, lambda: df1.join(df2).collect()) - - # works with crossJoin - self.assertEqual(1, df1.crossJoin(df2).count()) - - def test_conf(self): - spark = self.spark - spark.conf.set("bogo", "sipeo") - self.assertEqual(spark.conf.get("bogo"), "sipeo") - spark.conf.set("bogo", "ta") - self.assertEqual(spark.conf.get("bogo"), "ta") - self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") - self.assertEqual(spark.conf.get("not.set", "ta"), "ta") - self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set")) - spark.conf.unset("bogo") - self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") - - self.assertEqual(spark.conf.get("hyukjin", None), None) - - # This returns 'STATIC' because it's the default value of - # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in - # `spark.conf.get` is unset. - self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") - - # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but - # `defaultValue` in `spark.conf.get` is set to None. 
- self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) - - def test_current_database(self): - spark = self.spark - with self.database("some_db"): - self.assertEquals(spark.catalog.currentDatabase(), "default") - spark.sql("CREATE DATABASE some_db") - spark.catalog.setCurrentDatabase("some_db") - self.assertEquals(spark.catalog.currentDatabase(), "some_db") - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.setCurrentDatabase("does_not_exist")) - - def test_list_databases(self): - spark = self.spark - with self.database("some_db"): - databases = [db.name for db in spark.catalog.listDatabases()] - self.assertEquals(databases, ["default"]) - spark.sql("CREATE DATABASE some_db") - databases = [db.name for db in spark.catalog.listDatabases()] - self.assertEquals(sorted(databases), ["default", "some_db"]) - - def test_list_tables(self): - from pyspark.sql.catalog import Table - spark = self.spark - with self.database("some_db"): - spark.sql("CREATE DATABASE some_db") - with self.table("tab1", "some_db.tab2"): - with self.tempView("temp_tab"): - self.assertEquals(spark.catalog.listTables(), []) - self.assertEquals(spark.catalog.listTables("some_db"), []) - spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab") - spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") - spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet") - tables = sorted(spark.catalog.listTables(), key=lambda t: t.name) - tablesDefault = \ - sorted(spark.catalog.listTables("default"), key=lambda t: t.name) - tablesSomeDb = \ - sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name) - self.assertEquals(tables, tablesDefault) - self.assertEquals(len(tables), 2) - self.assertEquals(len(tablesSomeDb), 2) - self.assertEquals(tables[0], Table( - name="tab1", - database="default", - description=None, - tableType="MANAGED", - isTemporary=False)) - self.assertEquals(tables[1], Table( - name="temp_tab", - database=None, - description=None, - tableType="TEMPORARY", - isTemporary=True)) - self.assertEquals(tablesSomeDb[0], Table( - name="tab2", - database="some_db", - description=None, - tableType="MANAGED", - isTemporary=False)) - self.assertEquals(tablesSomeDb[1], Table( - name="temp_tab", - database=None, - description=None, - tableType="TEMPORARY", - isTemporary=True)) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.listTables("does_not_exist")) - - def test_list_functions(self): - from pyspark.sql.catalog import Function - spark = self.spark - with self.database("some_db"): - spark.sql("CREATE DATABASE some_db") - functions = dict((f.name, f) for f in spark.catalog.listFunctions()) - functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default")) - self.assertTrue(len(functions) > 200) - self.assertTrue("+" in functions) - self.assertTrue("like" in functions) - self.assertTrue("month" in functions) - self.assertTrue("to_date" in functions) - self.assertTrue("to_timestamp" in functions) - self.assertTrue("to_unix_timestamp" in functions) - self.assertTrue("current_database" in functions) - self.assertEquals(functions["+"], Function( - name="+", - description=None, - className="org.apache.spark.sql.catalyst.expressions.Add", - isTemporary=True)) - self.assertEquals(functions, functionsDefault) - - with self.function("func1", "some_db.func2"): - spark.catalog.registerFunction("temp_func", lambda x: str(x)) - spark.sql("CREATE FUNCTION func1 
AS 'org.apache.spark.data.bricks'") - spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'") - newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions()) - newFunctionsSomeDb = \ - dict((f.name, f) for f in spark.catalog.listFunctions("some_db")) - self.assertTrue(set(functions).issubset(set(newFunctions))) - self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb))) - self.assertTrue("temp_func" in newFunctions) - self.assertTrue("func1" in newFunctions) - self.assertTrue("func2" not in newFunctions) - self.assertTrue("temp_func" in newFunctionsSomeDb) - self.assertTrue("func1" not in newFunctionsSomeDb) - self.assertTrue("func2" in newFunctionsSomeDb) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.listFunctions("does_not_exist")) - - def test_list_columns(self): - from pyspark.sql.catalog import Column - spark = self.spark - with self.database("some_db"): - spark.sql("CREATE DATABASE some_db") - with self.table("tab1", "some_db.tab2"): - spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") - spark.sql( - "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet") - columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name) - columnsDefault = \ - sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name) - self.assertEquals(columns, columnsDefault) - self.assertEquals(len(columns), 2) - self.assertEquals(columns[0], Column( - name="age", - description=None, - dataType="int", - nullable=True, - isPartition=False, - isBucket=False)) - self.assertEquals(columns[1], Column( - name="name", - description=None, - dataType="string", - nullable=True, - isPartition=False, - isBucket=False)) - columns2 = \ - sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name) - self.assertEquals(len(columns2), 2) - self.assertEquals(columns2[0], Column( - name="nickname", - description=None, - dataType="string", - nullable=True, - isPartition=False, - isBucket=False)) - self.assertEquals(columns2[1], Column( - name="tolerance", - description=None, - dataType="float", - nullable=True, - isPartition=False, - isBucket=False)) - self.assertRaisesRegexp( - AnalysisException, - "tab2", - lambda: spark.catalog.listColumns("tab2")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.listColumns("does_not_exist")) - - def test_cache(self): - spark = self.spark - with self.tempView("tab1", "tab2"): - spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1") - spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2") - self.assertFalse(spark.catalog.isCached("tab1")) - self.assertFalse(spark.catalog.isCached("tab2")) - spark.catalog.cacheTable("tab1") - self.assertTrue(spark.catalog.isCached("tab1")) - self.assertFalse(spark.catalog.isCached("tab2")) - spark.catalog.cacheTable("tab2") - spark.catalog.uncacheTable("tab1") - self.assertFalse(spark.catalog.isCached("tab1")) - self.assertTrue(spark.catalog.isCached("tab2")) - spark.catalog.clearCache() - self.assertFalse(spark.catalog.isCached("tab1")) - self.assertFalse(spark.catalog.isCached("tab2")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.isCached("does_not_exist")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.cacheTable("does_not_exist")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: 
spark.catalog.uncacheTable("does_not_exist")) - - def test_read_text_file_list(self): - df = self.spark.read.text(['python/test_support/sql/text-test.txt', - 'python/test_support/sql/text-test.txt']) - count = df.count() - self.assertEquals(count, 4) - - def test_BinaryType_serialization(self): - # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808 - # The empty bytearray is test for SPARK-21534. - schema = StructType([StructField('mybytes', BinaryType())]) - data = [[bytearray(b'here is my data')], - [bytearray(b'and here is some more')], - [bytearray(b'')]] - df = self.spark.createDataFrame(data, schema=schema) - df.collect() - - # test for SPARK-16542 - def test_array_types(self): - # This test need to make sure that the Scala type selected is at least - # as large as the python's types. This is necessary because python's - # array types depend on C implementation on the machine. Therefore there - # is no machine independent correspondence between python's array types - # and Scala types. - # See: https://docs.python.org/2/library/array.html - - def assertCollectSuccess(typecode, value): - row = Row(myarray=array.array(typecode, [value])) - df = self.spark.createDataFrame([row]) - self.assertEqual(df.first()["myarray"][0], value) - - # supported string types - # - # String types in python's array are "u" for Py_UNICODE and "c" for char. - # "u" will be removed in python 4, and "c" is not supported in python 3. - supported_string_types = [] - if sys.version_info[0] < 4: - supported_string_types += ['u'] - # test unicode - assertCollectSuccess('u', u'a') - if sys.version_info[0] < 3: - supported_string_types += ['c'] - # test string - assertCollectSuccess('c', 'a') - - # supported float and double - # - # Test max, min, and precision for float and double, assuming IEEE 754 - # floating-point format. - supported_fractional_types = ['f', 'd'] - assertCollectSuccess('f', ctypes.c_float(1e+38).value) - assertCollectSuccess('f', ctypes.c_float(1e-38).value) - assertCollectSuccess('f', ctypes.c_float(1.123456).value) - assertCollectSuccess('d', sys.float_info.max) - assertCollectSuccess('d', sys.float_info.min) - assertCollectSuccess('d', sys.float_info.epsilon) - - # supported signed int types - # - # The size of C types changes with implementation, we need to make sure - # that there is no overflow error on the platform running this test. - supported_signed_int_types = list( - set(_array_signed_int_typecode_ctype_mappings.keys()) - .intersection(set(_array_type_mappings.keys()))) - for t in supported_signed_int_types: - ctype = _array_signed_int_typecode_ctype_mappings[t] - max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1) - assertCollectSuccess(t, max_val - 1) - assertCollectSuccess(t, -max_val) - - # supported unsigned int types - # - # JVM does not have unsigned types. We need to be very careful to make - # sure that there is no overflow error. - supported_unsigned_int_types = list( - set(_array_unsigned_int_typecode_ctype_mappings.keys()) - .intersection(set(_array_type_mappings.keys()))) - for t in supported_unsigned_int_types: - ctype = _array_unsigned_int_typecode_ctype_mappings[t] - assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1) - - # all supported types - # - # Make sure the types tested above: - # 1. are all supported types - # 2. 
cover all supported types - supported_types = (supported_string_types + - supported_fractional_types + - supported_signed_int_types + - supported_unsigned_int_types) - self.assertEqual(set(supported_types), set(_array_type_mappings.keys())) - - # all unsupported types - # - # Keys in _array_type_mappings is a complete list of all supported types, - # and types not in _array_type_mappings are considered unsupported. - # `array.typecodes` are not supported in python 2. - if sys.version_info[0] < 3: - all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd']) - else: - all_types = set(array.typecodes) - unsupported_types = all_types - set(supported_types) - # test unsupported types - for t in unsupported_types: - with self.assertRaises(TypeError): - a = array.array(t) - self.spark.createDataFrame([Row(myarray=a)]).collect() - - def test_bucketed_write(self): - data = [ - (1, "foo", 3.0), (2, "foo", 5.0), - (3, "bar", -1.0), (4, "bar", 6.0), - ] - df = self.spark.createDataFrame(data, ["x", "y", "z"]) - - def count_bucketed_cols(names, table="pyspark_bucket"): - """Given a sequence of column names and a table name - query the catalog and return number o columns which are - used for bucketing - """ - cols = self.spark.catalog.listColumns(table) - num = len([c for c in cols if c.name in names and c.isBucket]) - return num - - with self.table("pyspark_bucket"): - # Test write with one bucketing column - df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x"]), 1) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write two bucketing columns - df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x", "y"]), 2) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with bucket and sort - df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x"]), 1) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with a list of columns - df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x", "y"]), 2) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with bucket and sort with a list of columns - (df.write.bucketBy(2, "x") - .sortBy(["y", "z"]) - .mode("overwrite").saveAsTable("pyspark_bucket")) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with bucket and sort with multiple columns - (df.write.bucketBy(2, "x") - .sortBy("y", "z") - .mode("overwrite").saveAsTable("pyspark_bucket")) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - def _to_pandas(self): - from datetime import datetime, date - schema = StructType().add("a", IntegerType()).add("b", StringType())\ - .add("c", BooleanType()).add("d", FloatType())\ - .add("dt", DateType()).add("ts", TimestampType()) - data = [ - (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (2, "foo", True, 5.0, None, None), - (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)), - (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)), - ] - df = self.spark.createDataFrame(data, schema) - return df.toPandas() - - @unittest.skipIf(not _have_pandas, 
_pandas_requirement_message) - def test_to_pandas(self): - import numpy as np - pdf = self._to_pandas() - types = pdf.dtypes - self.assertEquals(types[0], np.int32) - self.assertEquals(types[1], np.object) - self.assertEquals(types[2], np.bool) - self.assertEquals(types[3], np.float32) - self.assertEquals(types[4], np.object) # datetime.date - self.assertEquals(types[5], 'datetime64[ns]') - - @unittest.skipIf(_have_pandas, "Required Pandas was found.") - def test_to_pandas_required_pandas_not_found(self): - with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): - self._to_pandas() - - @unittest.skipIf(not _have_pandas, _pandas_requirement_message) - def test_to_pandas_avoid_astype(self): - import numpy as np - schema = StructType().add("a", IntegerType()).add("b", StringType())\ - .add("c", IntegerType()) - data = [(1, "foo", 16777220), (None, "bar", None)] - df = self.spark.createDataFrame(data, schema) - types = df.toPandas().dtypes - self.assertEquals(types[0], np.float64) # doesn't convert to np.int32 due to NaN value. - self.assertEquals(types[1], np.object) - self.assertEquals(types[2], np.float64) - - def test_create_dataframe_from_array_of_long(self): - import array - data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] - df = self.spark.createDataFrame(data) - self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) - - @unittest.skipIf(not _have_pandas, _pandas_requirement_message) - def test_create_dataframe_from_pandas_with_timestamp(self): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]})[["d", "ts"]] - # test types are inferred correctly without specifying schema - df = self.spark.createDataFrame(pdf) - self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) - self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - # test with schema will accept pdf as input - df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp") - self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) - self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - - @unittest.skipIf(_have_pandas, "Required Pandas was found.") - def test_create_dataframe_required_pandas_not_found(self): - with QuietTest(self.sc): - with self.assertRaisesRegexp( - ImportError, - "(Pandas >= .* must be installed|No module named '?pandas'?)"): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]}) - self.spark.createDataFrame(pdf) - - # Regression test for SPARK-23360 - @unittest.skipIf(not _have_pandas, _pandas_requirement_message) - def test_create_dateframe_from_pandas_with_dst(self): - import pandas as pd - from datetime import datetime - - pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) - - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) - - orig_env_tz = os.environ.get('TZ', None) - try: - tz = 'America/Los_Angeles' - os.environ['TZ'] = tz - time.tzset() - with self.sql_conf({'spark.sql.session.timeZone': tz}): - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) - finally: - del os.environ['TZ'] - if orig_env_tz is not None: - os.environ['TZ'] = orig_env_tz - time.tzset() - - def test_sort_with_nulls_order(self): - from pyspark.sql import functions - - df = 
self.spark.createDataFrame( - [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) - self.assertEquals( - df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), - [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) - self.assertEquals( - df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), - [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) - self.assertEquals( - df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), - [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) - self.assertEquals( - df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), - [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) - - def test_json_sampling_ratio(self): - rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ - .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) - schema = self.spark.read.option('inferSchema', True) \ - .option('samplingRatio', 0.5) \ - .json(rdd).schema - self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) - - def test_csv_sampling_ratio(self): - rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ - .map(lambda x: '0.1' if x == 1 else str(x)) - schema = self.spark.read.option('inferSchema', True)\ - .csv(rdd, samplingRatio=0.5).schema - self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) - - def test_checking_csv_header(self): - path = tempfile.mkdtemp() - shutil.rmtree(path) - try: - self.spark.createDataFrame([[1, 1000], [2000, 2]])\ - .toDF('f1', 'f2').write.option("header", "true").csv(path) - schema = StructType([ - StructField('f2', IntegerType(), nullable=True), - StructField('f1', IntegerType(), nullable=True)]) - df = self.spark.read.option('header', 'true').schema(schema)\ - .csv(path, enforceSchema=False) - self.assertRaisesRegexp( - Exception, - "CSV header does not conform to the schema", - lambda: df.collect()) - finally: - shutil.rmtree(path) - - def test_ignore_column_of_all_nulls(self): - path = tempfile.mkdtemp() - shutil.rmtree(path) - try: - df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""], - ["""{"a":null, "b":null, "c":"string"}"""], - ["""{"a":null, "b":null, "c":null}"""]]) - df.write.text(path) - schema = StructType([ - StructField('b', LongType(), nullable=True), - StructField('c', StringType(), nullable=True)]) - readback = self.spark.read.json(path, dropFieldIfAllNull=True) - self.assertEquals(readback.schema, schema) - finally: - shutil.rmtree(path) - - # SPARK-24721 - @unittest.skipIf(not _test_compiled, _test_not_compiled_message) - def test_datasource_with_udf(self): - from pyspark.sql.functions import udf, lit, col - - path = tempfile.mkdtemp() - shutil.rmtree(path) - - try: - self.spark.range(1).write.mode("overwrite").format('csv').save(path) - filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') - datasource_df = self.spark.read \ - .format("org.apache.spark.sql.sources.SimpleScanSource") \ - .option('from', 0).option('to', 1).load().toDF('i') - datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load().toDF('i', 'j') - - c1 = udf(lambda x: x + 1, 'int')(lit(1)) - c2 = udf(lambda x: x + 1, 'int')(col('i')) - - f1 = udf(lambda x: False, 'boolean')(lit(1)) - f2 = udf(lambda x: False, 'boolean')(col('i')) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c1) - expected = df.withColumn('c', lit(2)) - self.assertEquals(expected.collect(), result.collect()) 
- - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c2) - expected = df.withColumn('c', col('i') + 1) - self.assertEquals(expected.collect(), result.collect()) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - for f in [f1, f2]: - result = df.filter(f) - self.assertEquals(0, result.count()) - finally: - shutil.rmtree(path) - - def test_repr_behaviors(self): - import re - pattern = re.compile(r'^ *\|', re.MULTILINE) - df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) - - # test when eager evaluation is enabled and _repr_html_ will not be called - with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): - expected1 = """+-----+-----+ - || key|value| - |+-----+-----+ - || 1| 1| - ||22222|22222| - |+-----+-----+ - |""" - self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): - expected2 = """+---+-----+ - ||key|value| - |+---+-----+ - || 1| 1| - ||222| 222| - |+---+-----+ - |""" - self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): - expected3 = """+---+-----+ - ||key|value| - |+---+-----+ - || 1| 1| - |+---+-----+ - |only showing top 1 row - |""" - self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) - - # test when eager evaluation is enabled and _repr_html_ will be called - with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): - expected1 = """<table border='1'> - |<tr><th>key</th><th>value</th></tr> - |<tr><td>1</td><td>1</td></tr> - |<tr><td>22222</td><td>22222</td></tr> - |</table> - |""" - self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) - with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): - expected2 = """<table border='1'> - |<tr><th>key</th><th>value</th></tr> - |<tr><td>1</td><td>1</td></tr> - |<tr><td>222</td><td>222</td></tr> - |</table> - |""" - self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) - with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): - expected3 = """<table border='1'> - |<tr><th>key</th><th>value</th></tr> - |<tr><td>1</td><td>1</td></tr> - |</table>
- |only showing top 1 row - |""" - self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) - - # test when eager evaluation is disabled and _repr_html_ will be called - with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): - expected = "DataFrame[key: bigint, value: string]" - self.assertEquals(None, df._repr_html_()) - self.assertEquals(expected, df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): - self.assertEquals(None, df._repr_html_()) - self.assertEquals(expected, df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): - self.assertEquals(None, df._repr_html_()) - self.assertEquals(expected, df.__repr__()) - - # SPARK-25591 - def test_same_accumulator_in_udfs(self): - from pyspark.sql.functions import udf - - data_schema = StructType([StructField("a", IntegerType(), True), - StructField("b", IntegerType(), True)]) - data = self.spark.createDataFrame([[1, 2]], schema=data_schema) - - test_accum = self.sc.accumulator(0) - - def first_udf(x): - test_accum.add(1) - return x - - def second_udf(x): - test_accum.add(100) - return x - - func_udf = udf(first_udf, IntegerType()) - func_udf2 = udf(second_udf, IntegerType()) - data = data.withColumn("out1", func_udf(data["a"])) - data = data.withColumn("out2", func_udf2(data["b"])) - data.collect() - self.assertEqual(test_accum.value, 101) - - -class HiveSparkSubmitTests(SparkSubmitTests): - - @classmethod - def setUpClass(cls): - # get a SparkContext to check for availability of Hive - sc = SparkContext('local[4]', cls.__name__) - cls.hive_available = True - try: - sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.hive_available = False - except TypeError: - cls.hive_available = False - finally: - # we don't need this SparkContext for the test - sc.stop() - - def setUp(self): - super(HiveSparkSubmitTests, self).setUp() - if not self.hive_available: - self.skipTest("Hive is not available.") - - def test_hivecontext(self): - # This test checks that HiveContext is using Hive metastore (SPARK-16224). - # It sets a metastore url and checks if there is a derby dir created by - # Hive metastore. If this derby dir exists, HiveContext is using - # Hive metastore. - metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db") - metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true" - hive_site_dir = os.path.join(self.programDir, "conf") - hive_site_file = self.createTempFile("hive-site.xml", (""" - |<configuration> - | <property> - |  <name>javax.jdo.option.ConnectionURL</name> - |  <value>%s</value> - | </property> - |</configuration> - """ % metastore_URL).lstrip(), "conf") - script = self.createTempFile("test.py", """ - |import os - | - |from pyspark.conf import SparkConf - |from pyspark.context import SparkContext - |from pyspark.sql import HiveContext - | - |conf = SparkConf() - |sc = SparkContext(conf=conf) - |hive_context = HiveContext(sc) - |print(hive_context.sql("show databases").collect()) - """) - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("default", out.decode('utf-8')) - self.assertTrue(os.path.exists(metastore_path)) - - -class SQLTests2(ReusedSQLTestCase): - - # We can't include this test into SQLTests because we will stop class's SparkContext and cause - # other tests to fail. 
- def test_sparksession_with_stopped_sparkcontext(self): - self.sc.stop() - sc = SparkContext('local[4]', self.sc.appName) - spark = SparkSession.builder.getOrCreate() - try: - df = spark.createDataFrame([(1, 2)], ["c", "c"]) - df.collect() - finally: - spark.stop() - sc.stop() - - -class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): - # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is - # static and immutable. This can't be set or unset, for example, via `spark.conf`. - - @classmethod - def setUpClass(cls): - import glob - from pyspark.find_spark_home import _find_spark_home - - SPARK_HOME = _find_spark_home() - filename_pattern = ( - "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" - "TestQueryExecutionListener.class") - cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) - - if cls.has_listener: - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. - cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - - def setUp(self): - if not self.has_listener: - raise self.skipTest( - "'org.apache.spark.sql.TestQueryExecutionListener' is not " - "available. Will skip the related tests.") - - @classmethod - def tearDownClass(cls): - if hasattr(cls, "spark"): - cls.spark.stop() - - def tearDown(self): - self.spark._jvm.OnSuccessCall.clear() - - def test_query_execution_listener_on_collect(self): - self.assertFalse( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should not be called before 'collect'") - self.spark.sql("SELECT * FROM range(1)").collect() - self.assertTrue( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should be called after 'collect'") - - @unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) - def test_query_execution_listener_on_collect_with_arrow(self): - with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): - self.assertFalse( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should not be " - "called before 'toPandas'") - self.spark.sql("SELECT * FROM range(1)").toPandas() - self.assertTrue( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should be called after 'toPandas'") - - -class SparkExtensionsTest(unittest.TestCase): - # These tests are separate because it uses 'spark.sql.extensions' which is - # static and immutable. This can't be set or unset, for example, via `spark.conf`. - - @classmethod - def setUpClass(cls): - import glob - from pyspark.find_spark_home import _find_spark_home - - SPARK_HOME = _find_spark_home() - filename_pattern = ( - "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" - "SparkSessionExtensionSuite.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( - "'org.apache.spark.sql.SparkSessionExtensionSuite' is not " - "available. Will skip the related tests.") - - # Note that 'spark.sql.extensions' is a static immutable configuration. 
- cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.extensions", - "org.apache.spark.sql.MyExtensions") \ - .getOrCreate() - - @classmethod - def tearDownClass(cls): - cls.spark.stop() - - def test_use_custom_class_for_extensions(self): - self.assertTrue( - self.spark._jsparkSession.sessionState().planner().strategies().contains( - self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)), - "MySparkStrategy not found in active planner strategies") - self.assertTrue( - self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains( - self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)), - "MyRule not found in extended resolution rules") - - -class SparkSessionTests(PySparkTestCase): - - # This test is separate because it's closely related with session's start and stop. - # See SPARK-23228. - def test_set_jvm_default_session(self): - spark = SparkSession.builder.getOrCreate() - try: - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) - finally: - spark.stop() - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) - - def test_jvm_default_session_already_set(self): - # Here, we assume there is the default session already set in JVM. - jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) - self.sc._jvm.SparkSession.setDefaultSession(jsession) - - spark = SparkSession.builder.getOrCreate() - try: - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) - # The session should be the same with the exiting one. - self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) - finally: - spark.stop() - - -class SparkSessionTests2(unittest.TestCase): - - def test_active_session(self): - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - try: - activeSession = SparkSession.getActiveSession() - df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) - self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) - finally: - spark.stop() - - def test_get_active_session_when_no_active_session(self): - active = SparkSession.getActiveSession() - self.assertEqual(active, None) - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - active = SparkSession.getActiveSession() - self.assertEqual(active, spark) - spark.stop() - active = SparkSession.getActiveSession() - self.assertEqual(active, None) - - def test_SparkSession(self): - spark = SparkSession.builder \ - .master("local") \ - .config("some-config", "v2") \ - .getOrCreate() - try: - self.assertEqual(spark.conf.get("some-config"), "v2") - self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") - self.assertEqual(spark.version, spark.sparkContext.version) - spark.sql("CREATE DATABASE test_db") - spark.catalog.setCurrentDatabase("test_db") - self.assertEqual(spark.catalog.currentDatabase(), "test_db") - spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") - self.assertEqual(spark.table("table1").columns, ['name', 'age']) - self.assertEqual(spark.range(3).count(), 3) - finally: - spark.stop() - - def test_global_default_session(self): - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - try: - self.assertEqual(SparkSession.builder.getOrCreate(), spark) - finally: - spark.stop() - - def test_default_and_active_session(self): - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - activeSession = spark._jvm.SparkSession.getActiveSession() 
- defaultSession = spark._jvm.SparkSession.getDefaultSession() - try: - self.assertEqual(activeSession, defaultSession) - finally: - spark.stop() - - def test_config_option_propagated_to_existing_session(self): - session1 = SparkSession.builder \ - .master("local") \ - .config("spark-config1", "a") \ - .getOrCreate() - self.assertEqual(session1.conf.get("spark-config1"), "a") - session2 = SparkSession.builder \ - .config("spark-config1", "b") \ - .getOrCreate() - try: - self.assertEqual(session1, session2) - self.assertEqual(session1.conf.get("spark-config1"), "b") - finally: - session1.stop() - - def test_new_session(self): - session = SparkSession.builder \ - .master("local") \ - .getOrCreate() - newSession = session.newSession() - try: - self.assertNotEqual(session, newSession) - finally: - session.stop() - newSession.stop() - - def test_create_new_session_if_old_session_stopped(self): - session = SparkSession.builder \ - .master("local") \ - .getOrCreate() - session.stop() - newSession = SparkSession.builder \ - .master("local") \ - .getOrCreate() - try: - self.assertNotEqual(session, newSession) - finally: - newSession.stop() - - def test_active_session_with_None_and_not_None_context(self): - from pyspark.context import SparkContext - from pyspark.conf import SparkConf - sc = None - session = None - try: - sc = SparkContext._active_spark_context - self.assertEqual(sc, None) - activeSession = SparkSession.getActiveSession() - self.assertEqual(activeSession, None) - sparkConf = SparkConf() - sc = SparkContext.getOrCreate(sparkConf) - activeSession = sc._jvm.SparkSession.getActiveSession() - self.assertFalse(activeSession.isDefined()) - session = SparkSession(sc) - activeSession = sc._jvm.SparkSession.getActiveSession() - self.assertTrue(activeSession.isDefined()) - activeSession2 = SparkSession.getActiveSession() - self.assertNotEqual(activeSession2, None) - finally: - if session is not None: - session.stop() - if sc is not None: - sc.stop() - - -class SparkSessionTests3(ReusedSQLTestCase): - - def test_get_active_session_after_create_dataframe(self): - session2 = None - try: - activeSession1 = SparkSession.getActiveSession() - session1 = self.spark - self.assertEqual(session1, activeSession1) - session2 = self.spark.newSession() - activeSession2 = SparkSession.getActiveSession() - self.assertEqual(session1, activeSession2) - self.assertNotEqual(session2, activeSession2) - session2.createDataFrame([(1, 'Alice')], ['age', 'name']) - activeSession3 = SparkSession.getActiveSession() - self.assertEqual(session2, activeSession3) - session1.createDataFrame([(1, 'Alice')], ['age', 'name']) - activeSession4 = SparkSession.getActiveSession() - self.assertEqual(session1, activeSession4) - finally: - if session2 is not None: - session2.stop() - - -class UDFInitializationTests(unittest.TestCase): - def tearDown(self): - if SparkSession._instantiatedSession is not None: - SparkSession._instantiatedSession.stop() - - if SparkContext._active_spark_context is not None: - SparkContext._active_spark_context.stop() - - def test_udf_init_shouldnt_initialize_context(self): - from pyspark.sql.functions import UserDefinedFunction - - UserDefinedFunction(lambda x: x, StringType()) - - self.assertIsNone( - SparkContext._active_spark_context, - "SparkContext shouldn't be initialized when UserDefinedFunction is created." - ) - self.assertIsNone( - SparkSession._instantiatedSession, - "SparkSession shouldn't be initialized when UserDefinedFunction is created." 
- ) - - -class HiveContextSQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - cls.hive_available = True - try: - cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.hive_available = False - except TypeError: - cls.hive_available = False - os.unlink(cls.tempdir.name) - if cls.hive_available: - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() - - def setUp(self): - if not self.hive_available: - self.skipTest("Hive is not available.") - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def test_save_and_load_table(self): - df = self.df - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) - actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json") - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("DROP TABLE externalJsonTable") - - df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) - schema = StructType([StructField("value", StringType(), True)]) - actual = self.spark.createExternalTable("externalJsonTable", source="json", - schema=schema, path=tmpPath, - noUse="this options will not be used") - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) - self.assertEqual(sorted(df.select("value").collect()), - sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - self.spark.sql("DROP TABLE savedJsonTable") - self.spark.sql("DROP TABLE externalJsonTable") - - defaultDataSourceName = self.spark.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") - actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath) - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("DROP TABLE savedJsonTable") - self.spark.sql("DROP TABLE externalJsonTable") - self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) - - shutil.rmtree(tmpPath) - - def test_window_functions(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - w = Window.partitionBy("value").orderBy("key") - from pyspark.sql import functions as F - sel = df.select(df.value, df.key, - F.max("key").over(w.rowsBetween(0, 1)), - F.min("key").over(w.rowsBetween(0, 1)), - F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.row_number().over(w), - F.rank().over(w), - F.dense_rank().over(w), - F.ntile(2).over(w)) - rs = sorted(sel.collect()) - expected = [ - ("1", 1, 1, 1, 1, 1, 1, 1, 
1), - ("2", 1, 1, 1, 3, 1, 1, 1, 1), - ("2", 1, 2, 1, 3, 2, 1, 1, 1), - ("2", 2, 2, 2, 3, 3, 3, 2, 2) - ] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - def test_window_functions_without_partitionBy(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - w = Window.orderBy("key", df.value) - from pyspark.sql import functions as F - sel = df.select(df.value, df.key, - F.max("key").over(w.rowsBetween(0, 1)), - F.min("key").over(w.rowsBetween(0, 1)), - F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.row_number().over(w), - F.rank().over(w), - F.dense_rank().over(w), - F.ntile(2).over(w)) - rs = sorted(sel.collect()) - expected = [ - ("1", 1, 1, 1, 4, 1, 1, 1, 1), - ("2", 1, 1, 1, 4, 2, 2, 2, 1), - ("2", 1, 2, 1, 4, 3, 2, 2, 2), - ("2", 2, 2, 2, 4, 4, 4, 3, 2) - ] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - def test_window_functions_cumulative_sum(self): - df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) - from pyspark.sql import functions as F - - # Test cumulative sum - sel = df.select( - df.key, - F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0))) - rs = sorted(sel.collect()) - expected = [("one", 1), ("two", 3)] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow - sel = df.select( - df.key, - F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0))) - rs = sorted(sel.collect()) - expected = [("one", 1), ("two", 3)] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow - frame_end = Window.unboundedFollowing + 1 - sel = df.select( - df.key, - F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end))) - rs = sorted(sel.collect()) - expected = [("one", 3), ("two", 2)] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - def test_collect_functions(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql import functions - - self.assertEqual( - sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), - [1, 2]) - self.assertEqual( - sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), - [1, 1, 1, 2]) - self.assertEqual( - sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), - ["1", "2"]) - self.assertEqual( - sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), - ["1", "2", "2", "2"]) - - def test_limit_and_take(self): - df = self.spark.range(1, 1000, numPartitions=10) - - def assert_runs_only_one_job_stage_and_task(job_group_name, f): - tracker = self.sc.statusTracker() - self.sc.setJobGroup(job_group_name, description="") - f() - jobs = tracker.getJobIdsForGroup(job_group_name) - self.assertEqual(1, len(jobs)) - stages = tracker.getJobInfo(jobs[0]).stageIds - self.assertEqual(1, len(stages)) - self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks) - - # Regression test for SPARK-10731: take should delegate to Scala implementation - assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1)) - # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n) - assert_runs_only_one_job_stage_and_task("collect_limit", lambda: 
df.limit(1).collect()) - - def test_datetime_functions(self): - from pyspark.sql import functions - from datetime import date, datetime - df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol") - parse_result = df.select(functions.to_date(functions.col("dateCol"))).first() - self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)']) - - @unittest.skipIf(sys.version_info < (3, 3), "Unittest < 3.3 doesn't support mocking") - def test_unbounded_frames(self): - from unittest.mock import patch - from pyspark.sql import functions as F - from pyspark.sql import window - import importlib - - df = self.spark.range(0, 3) - - def rows_frame_match(): - return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( - F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize)) - ).columns[0] - - def range_frame_match(): - return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( - F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize)) - ).columns[0] - - with patch("sys.maxsize", 2 ** 31 - 1): - importlib.reload(window) - self.assertTrue(rows_frame_match()) - self.assertTrue(range_frame_match()) - - with patch("sys.maxsize", 2 ** 63 - 1): - importlib.reload(window) - self.assertTrue(rows_frame_match()) - self.assertTrue(range_frame_match()) - - with patch("sys.maxsize", 2 ** 127 - 1): - importlib.reload(window) - self.assertTrue(rows_frame_match()) - self.assertTrue(range_frame_match()) - - importlib.reload(window) - - -class DataTypeVerificationTests(unittest.TestCase): - - def test_verify_type_exception_msg(self): - self.assertRaisesRegexp( - ValueError, - "test_name", - lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None)) - - schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))]) - self.assertRaisesRegexp( - TypeError, - "field b in field a", - lambda: _make_type_verifier(schema)([["data"]])) - - def test_verify_type_ok_nullable(self): - obj = None - types = [IntegerType(), FloatType(), StringType(), StructType([])] - for data_type in types: - try: - _make_type_verifier(data_type, nullable=True)(obj) - except Exception: - self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type)) - - def test_verify_type_not_nullable(self): - import array - import datetime - import decimal - - schema = StructType([ - StructField('s', StringType(), nullable=False), - StructField('i', IntegerType(), nullable=True)]) - - class MyObj: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - # obj, data_type - success_spec = [ - # String - ("", StringType()), - (u"", StringType()), - (1, StringType()), - (1.0, StringType()), - ([], StringType()), - ({}, StringType()), - - # UDT - (ExamplePoint(1.0, 2.0), ExamplePointUDT()), - - # Boolean - (True, BooleanType()), - - # Byte - (-(2**7), ByteType()), - (2**7 - 1, ByteType()), - - # Short - (-(2**15), ShortType()), - (2**15 - 1, ShortType()), - - # Integer - (-(2**31), IntegerType()), - (2**31 - 1, IntegerType()), - - # Long - (2**64, LongType()), - - # Float & Double - (1.0, FloatType()), - (1.0, DoubleType()), - - # Decimal - (decimal.Decimal("1.0"), DecimalType()), - - # Binary - (bytearray([1, 2]), BinaryType()), - - # Date/Timestamp - (datetime.date(2000, 1, 2), DateType()), - (datetime.datetime(2000, 1, 2, 3, 4), DateType()), - (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()), - - # Array - ([], ArrayType(IntegerType())), - (["1", None], ArrayType(StringType(), 
containsNull=True)), - ([1, 2], ArrayType(IntegerType())), - ((1, 2), ArrayType(IntegerType())), - (array.array('h', [1, 2]), ArrayType(IntegerType())), - - # Map - ({}, MapType(StringType(), IntegerType())), - ({"a": 1}, MapType(StringType(), IntegerType())), - ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)), - - # Struct - ({"s": "a", "i": 1}, schema), - ({"s": "a", "i": None}, schema), - ({"s": "a"}, schema), - ({"s": "a", "f": 1.0}, schema), - (Row(s="a", i=1), schema), - (Row(s="a", i=None), schema), - (Row(s="a", i=1, f=1.0), schema), - (["a", 1], schema), - (["a", None], schema), - (("a", 1), schema), - (MyObj(s="a", i=1), schema), - (MyObj(s="a", i=None), schema), - (MyObj(s="a"), schema), - ] - - # obj, data_type, exception class - failure_spec = [ - # String (match anything but None) - (None, StringType(), ValueError), - - # UDT - (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), - - # Boolean - (1, BooleanType(), TypeError), - ("True", BooleanType(), TypeError), - ([1], BooleanType(), TypeError), - - # Byte - (-(2**7) - 1, ByteType(), ValueError), - (2**7, ByteType(), ValueError), - ("1", ByteType(), TypeError), - (1.0, ByteType(), TypeError), - - # Short - (-(2**15) - 1, ShortType(), ValueError), - (2**15, ShortType(), ValueError), - - # Integer - (-(2**31) - 1, IntegerType(), ValueError), - (2**31, IntegerType(), ValueError), - - # Float & Double - (1, FloatType(), TypeError), - (1, DoubleType(), TypeError), - - # Decimal - (1.0, DecimalType(), TypeError), - (1, DecimalType(), TypeError), - ("1.0", DecimalType(), TypeError), - - # Binary - (1, BinaryType(), TypeError), - - # Date/Timestamp - ("2000-01-02", DateType(), TypeError), - (946811040, TimestampType(), TypeError), - - # Array - (["1", None], ArrayType(StringType(), containsNull=False), ValueError), - ([1, "2"], ArrayType(IntegerType()), TypeError), - - # Map - ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), - ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), - ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), - ValueError), - - # Struct - ({"s": "a", "i": "1"}, schema, TypeError), - (Row(s="a"), schema, ValueError), # Row can't have missing field - (Row(s="a", i="1"), schema, TypeError), - (["a"], schema, ValueError), - (["a", "1"], schema, TypeError), - (MyObj(s="a", i="1"), schema, TypeError), - (MyObj(s=None, i="1"), schema, ValueError), - ] - - # Check success cases - for obj, data_type in success_spec: - try: - _make_type_verifier(data_type, nullable=False)(obj) - except Exception: - self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type)) - - # Check failure cases - for obj, data_type, exp in failure_spec: - msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) - with self.assertRaises(exp, msg=msg): - _make_type_verifier(data_type, nullable=False)(obj) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class ArrowTests(ReusedSQLTestCase): - - @classmethod - def setUpClass(cls): - from datetime import date, datetime - from decimal import Decimal - from distutils.version import LooseVersion - import pyarrow as pa - super(ArrowTests, cls).setUpClass() - cls.warnings_lock = threading.Lock() - - # Synchronize default timezone between Python and Java - cls.tz_prev = os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" - os.environ["TZ"] = tz - time.tzset() - - 
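        # The conf calls that follow switch the test session to Arrow-backed
        # pandas conversion. As a hedged standalone sketch (the `spark` session
        # name here is assumed, not taken from this patch), the same toggle in
        # user code looks like:
        #
        #     spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        #     spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
        #     pdf = spark.range(3).toPandas()  # conversion goes through Arrow
        #
        # Keeping the fallback disabled makes unsupported types raise instead of
        # silently taking the slower non-Arrow path, which is how these tests
        # surface failures.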
cls.spark.conf.set("spark.sql.session.timeZone", tz) - cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") - # Disable fallback by default to easily detect the failures. - cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") - cls.schema = StructType([ - StructField("1_str_t", StringType(), True), - StructField("2_int_t", IntegerType(), True), - StructField("3_long_t", LongType(), True), - StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True), - StructField("6_decimal_t", DecimalType(38, 18), True), - StructField("7_date_t", DateType(), True), - StructField("8_timestamp_t", TimestampType(), True)]) - cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), - date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), - date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] - - # TODO: remove version check once minimum pyarrow version is 0.10.0 - if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): - cls.schema.add(StructField("9_binary_t", BinaryType(), True)) - cls.data[0] = cls.data[0] + (bytearray(b"a"),) - cls.data[1] = cls.data[1] + (bytearray(b"bb"),) - cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) - - @classmethod - def tearDownClass(cls): - del os.environ["TZ"] - if cls.tz_prev is not None: - os.environ["TZ"] = cls.tz_prev - time.tzset() - super(ArrowTests, cls).tearDownClass() - - def create_pandas_data_frame(self): - import pandas as pd - import numpy as np - data_dict = {} - for j, name in enumerate(self.schema.names): - data_dict[name] = [self.data[i][j] for i in range(len(self.data))] - # need to convert these to numpy types first - data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) - data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - return pd.DataFrame(data=data_dict) - - def test_toPandas_fallback_enabled(self): - import pandas as pd - - with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) - with QuietTest(self.sc): - with self.warnings_lock: - with warnings.catch_warnings(record=True) as warns: - # we want the warnings to appear even if this test is run from a subclass - warnings.simplefilter("always") - pdf = df.toPandas() - # Catch and check the last UserWarning. 
- user_warns = [ - warn.message for warn in warns if isinstance(warn.message, UserWarning)] - self.assertTrue(len(user_warns) > 0) - self.assertTrue( - "Attempting non-optimization" in _exception_message(user_warns[-1])) - self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) - - def test_toPandas_fallback_disabled(self): - from distutils.version import LooseVersion - import pyarrow as pa - - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - with QuietTest(self.sc): - with self.warnings_lock: - with self.assertRaisesRegexp(Exception, 'Unsupported type'): - df.toPandas() - - # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - schema = StructType([StructField("binary", BinaryType(), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): - df.toPandas() - - def test_null_conversion(self): - df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + - self.data) - pdf = df_null.toPandas() - null_counts = pdf.isnull().sum().tolist() - self.assertTrue(all([c == 1 for c in null_counts])) - - def _toPandas_arrow_toggle(self, df): - with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): - pdf = df.toPandas() - - pdf_arrow = df.toPandas() - - return pdf, pdf_arrow - - @unittest.skip("This test flakes depending on system timezone") - def test_toPandas_arrow_toggle(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - expected = self.create_pandas_data_frame() - self.assertPandasEqual(expected, pdf) - self.assertPandasEqual(expected, pdf_arrow) - - @unittest.skip("This test flakes depending on system timezone") - def test_toPandas_respect_session_timezone(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) - - timezone = "America/New_York" - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": False, - "spark.sql.session.timeZone": timezone}): - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_la, pdf_la) - - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": True, - "spark.sql.session.timeZone": timezone}): - pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_ny, pdf_ny) - - self.assertFalse(pdf_ny.equals(pdf_la)) - - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - pdf_la_corrected = pdf_la.copy() - for field in self.schema: - if isinstance(field.dataType, TimestampType): - pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( - pdf_la_corrected[field.name], timezone) - self.assertPandasEqual(pdf_ny, pdf_la_corrected) - - @unittest.skip("This test flakes depending on system timezone") - def test_pandas_round_trip(self): - pdf = self.create_pandas_data_frame() - df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf_arrow = df.toPandas() - self.assertPandasEqual(pdf_arrow, pdf) - - def test_filtered_frame(self): - df = self.spark.range(3).toDF("i") - pdf = df.filter("i < 0").toPandas() - self.assertEqual(len(pdf.columns), 1) - self.assertEqual(pdf.columns[0], "i") - self.assertTrue(pdf.empty) - - def _createDataFrame_toggle(self, pdf, schema=None): - with 
self.sql_conf({"spark.sql.execution.arrow.enabled": False}): - df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - - df_arrow = self.spark.createDataFrame(pdf, schema=schema) - - return df_no_arrow, df_arrow - - @unittest.skip("This test flakes depending on system timezone") - def test_createDataFrame_toggle(self): - pdf = self.create_pandas_data_frame() - df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema) - self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) - - @unittest.skip("This test flakes depending on system timezone") - def test_createDataFrame_respect_session_timezone(self): - from datetime import timedelta - pdf = self.create_pandas_data_frame() - timezone = "America/New_York" - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": False, - "spark.sql.session.timeZone": timezone}): - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": True, - "spark.sql.session.timeZone": timezone}): - df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) - result_ny = df_no_arrow_ny.collect() - result_arrow_ny = df_arrow_ny.collect() - self.assertEqual(result_ny, result_arrow_ny) - - self.assertNotEqual(result_ny, result_la) - - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v - for k, v in row.asDict().items()}) - for row in result_la] - self.assertEqual(result_ny, result_la_corrected) - - def test_createDataFrame_with_schema(self): - pdf = self.create_pandas_data_frame() - df = self.spark.createDataFrame(pdf, schema=self.schema) - self.assertEquals(self.schema, df.schema) - pdf_arrow = df.toPandas() - self.assertPandasEqual(pdf_arrow, pdf) - - def test_createDataFrame_with_incorrect_schema(self): - pdf = self.create_pandas_data_frame() - fields = list(self.schema) - fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp - wrong_schema = StructType(fields) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): - self.spark.createDataFrame(pdf, schema=wrong_schema) - - def test_createDataFrame_with_names(self): - pdf = self.create_pandas_data_frame() - new_names = list(map(str, range(len(self.schema.fieldNames())))) - # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list(new_names)) - self.assertEquals(df.schema.fieldNames(), new_names) - # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) - self.assertEquals(df.schema.fieldNames(), new_names) - - def test_createDataFrame_column_name_encoding(self): - import pandas as pd - pdf = pd.DataFrame({u'a': [1]}) - columns = self.spark.createDataFrame(pdf).columns - self.assertTrue(isinstance(columns[0], str)) - self.assertEquals(columns[0], 'a') - columns = self.spark.createDataFrame(pdf, [u'b']).columns - self.assertTrue(isinstance(columns[0], str)) - self.assertEquals(columns[0], 'b') - - def test_createDataFrame_with_single_data_type(self): - import pandas as pd - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): - 
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") - - def test_createDataFrame_does_not_modify_input(self): - import pandas as pd - # Some series get converted for Spark to consume, this makes sure input is unchanged - pdf = self.create_pandas_data_frame() - # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) - # Integers with nulls will get NaNs filled with 0 and will be casted - pdf.ix[1, '2_int_t'] = None - pdf_copy = pdf.copy(deep=True) - self.spark.createDataFrame(pdf, schema=self.schema) - self.assertTrue(pdf.equals(pdf_copy)) - - def test_schema_conversion_roundtrip(self): - from pyspark.sql.types import from_arrow_schema, to_arrow_schema - arrow_schema = to_arrow_schema(self.schema) - schema_rt = from_arrow_schema(arrow_schema) - self.assertEquals(self.schema, schema_rt) - - def test_createDataFrame_with_array_type(self): - import pandas as pd - pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) - df, df_arrow = self._createDataFrame_toggle(pdf) - result = df.collect() - result_arrow = df_arrow.collect() - expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] - for r in range(len(expected)): - for e in range(len(expected[r])): - self.assertTrue(expected[r][e] == result_arrow[r][e] and - result[r][e] == result_arrow[r][e]) - - def test_toPandas_with_array_type(self): - expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] - array_schema = StructType([StructField("a", ArrayType(IntegerType())), - StructField("b", ArrayType(StringType()))]) - df = self.spark.createDataFrame(expected, schema=array_schema) - pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] - result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)] - for r in range(len(expected)): - for e in range(len(expected[r])): - self.assertTrue(expected[r][e] == result_arrow[r][e] and - result[r][e] == result_arrow[r][e]) - - def test_createDataFrame_with_int_col_names(self): - import numpy as np - import pandas as pd - pdf = pd.DataFrame(np.random.rand(4, 2)) - df, df_arrow = self._createDataFrame_toggle(pdf) - pdf_col_names = [str(c) for c in pdf.columns] - self.assertEqual(pdf_col_names, df.columns) - self.assertEqual(pdf_col_names, df_arrow.columns) - - @unittest.skip("This test flakes depending on system timezone") - def test_createDataFrame_fallback_enabled(self): - import pandas as pd - - with QuietTest(self.sc): - with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): - with warnings.catch_warnings(record=True) as warns: - # we want the warnings to appear even if this test is run from a subclass - warnings.simplefilter("always") - df = self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") - # Catch and check the last UserWarning. 
- user_warns = [ - warn.message for warn in warns if isinstance(warn.message, UserWarning)] - self.assertTrue(len(user_warns) > 0) - self.assertTrue( - "Attempting non-optimization" in _exception_message(user_warns[-1])) - self.assertEqual(df.collect(), [Row(a={u'a': 1})]) - - def test_createDataFrame_fallback_disabled(self): - from distutils.version import LooseVersion - import pandas as pd - import pyarrow as pa - - with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, 'Unsupported type'): - self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") - - # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): - self.spark.createDataFrame( - pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") - - # Regression test for SPARK-23314 - @unittest.skip("This test flakes depending on system timezone") - def test_timestamp_dst(self): - import pandas as pd - # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] - pdf = pd.DataFrame({'time': dt}) - - df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') - df_from_pandas = self.spark.createDataFrame(pdf) - - self.assertPandasEqual(pdf, df_from_python.toPandas()) - self.assertPandasEqual(pdf, df_from_pandas.toPandas()) - - -class EncryptionArrowTests(ArrowTests): - - @classmethod - def conf(cls): - return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true") - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class PandasUDFTests(ReusedSQLTestCase): - - def test_pandas_udf_basic(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - - udf = pandas_udf(lambda x: x, DoubleType()) - self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) - self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) - self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), - PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, 'v double', - functionType=PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, returnType='v double', - functionType=PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, 
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - def test_pandas_udf_decorator(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - from pyspark.sql.types import StructType, StructField, DoubleType - - @pandas_udf(DoubleType()) - def foo(x): - return x - self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - @pandas_udf(returnType=DoubleType()) - def foo(x): - return x - self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - schema = StructType([StructField("v", DoubleType())]) - - @pandas_udf(schema, PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - @pandas_udf('v double', PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) - def foo(x): - return x - self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - def test_udf_wrong_arg(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with QuietTest(self.sc): - with self.assertRaises(ParseException): - @pandas_udf('blah') - def foo(x): - return x - with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'): - @pandas_udf(functionType=PandasUDFType.SCALAR) - def foo(x): - return x - with self.assertRaisesRegexp(ValueError, 'Invalid functionType'): - @pandas_udf('double', 100) - def foo(x): - return x - - with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): - pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR) - with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): - @pandas_udf(LongType(), PandasUDFType.SCALAR) - def zero_with_type(): - return 1 - - with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) - def foo(df): - return df - with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) - def foo(df): - return df - with self.assertRaisesRegexp(ValueError, 'Invalid function'): - @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) - def foo(k, v, w): - return k - - def test_stopiteration_in_udf(self): - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType - from py4j.protocol import Py4JJavaError - - def foo(x): - raise StopIteration() - - def foofoo(x, y): - raise StopIteration() - - exc_message = "Caught StopIteration thrown from user's code; failing the task" - df = self.spark.range(0, 100) - - # plain udf (test for SPARK-23754) - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.withColumn('v', udf(foo)('id')).collect - ) - - # pandas scalar udf - self.assertRaisesRegexp( - 
Py4JJavaError, - exc_message, - df.withColumn( - 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') - ).collect - ) - - # pandas grouped map - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.groupBy('id').apply( - pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect - ) - - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.groupBy('id').apply( - pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect - ) - - # pandas grouped agg - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.groupBy('id').agg( - pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') - ).collect - ) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class ScalarPandasUDFTests(ReusedSQLTestCase): - - @classmethod - def setUpClass(cls): - ReusedSQLTestCase.setUpClass() - - # Synchronize default timezone between Python and Java - cls.tz_prev = os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" - os.environ["TZ"] = tz - time.tzset() - - cls.sc.environment["TZ"] = tz - cls.spark.conf.set("spark.sql.session.timeZone", tz) - - @classmethod - def tearDownClass(cls): - del os.environ["TZ"] - if cls.tz_prev is not None: - os.environ["TZ"] = cls.tz_prev - time.tzset() - ReusedSQLTestCase.tearDownClass() - - @property - def nondeterministic_vectorized_udf(self): - from pyspark.sql.functions import pandas_udf - - @pandas_udf('double') - def random_udf(v): - import pandas as pd - import numpy as np - return pd.Series(np.random.random(len(v))) - random_udf = random_udf.asNondeterministic() - return random_udf - - def test_pandas_udf_tokenize(self): - from pyspark.sql.functions import pandas_udf - tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), - ArrayType(StringType())) - self.assertEqual(tokenize.returnType, ArrayType(StringType())) - df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) - result = df.select(tokenize("vals").alias("hi")) - self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) - - def test_pandas_udf_nested_arrays(self): - from pyspark.sql.functions import pandas_udf - tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), - ArrayType(ArrayType(StringType()))) - self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) - df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) - result = df.select(tokenize("vals").alias("hi")) - self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) - - def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col, array - df = self.spark.range(10).select( - col('id').cast('string').alias('str'), - col('id').cast('int').alias('int'), - col('id').alias('long'), - col('id').cast('float').alias('float'), - col('id').cast('double').alias('double'), - col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool'), - array(col('id')).alias('array_long')) - f = lambda x: x - str_f = pandas_udf(f, StringType()) - int_f = pandas_udf(f, IntegerType()) - long_f = pandas_udf(f, LongType()) - float_f = pandas_udf(f, FloatType()) - double_f = pandas_udf(f, DoubleType()) - decimal_f = pandas_udf(f, DecimalType()) - bool_f = pandas_udf(f, BooleanType()) - array_long_f = pandas_udf(f, ArrayType(LongType())) - res = df.select(str_f(col('str')), int_f(col('int')), - long_f(col('long')), float_f(col('float')), - 
double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool')), array_long_f('array_long')) - self.assertEquals(df.collect(), res.collect()) - - def test_register_nondeterministic_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - import random - random_pandas_udf = pandas_udf( - lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() - self.assertEqual(random_pandas_udf.deterministic, False) - self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - nondeterministic_pandas_udf = self.spark.catalog.registerFunction( - "randomPandasUDF", random_pandas_udf) - self.assertEqual(nondeterministic_pandas_udf.deterministic, False) - self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() - self.assertEqual(row[0], 7) - - def test_vectorized_udf_null_boolean(self): - from pyspark.sql.functions import pandas_udf, col - data = [(True,), (True,), (None,), (False,)] - schema = StructType().add("bool", BooleanType()) - df = self.spark.createDataFrame(data, schema) - bool_f = pandas_udf(lambda x: x, BooleanType()) - res = df.select(bool_f(col('bool'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_byte(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("byte", ByteType()) - df = self.spark.createDataFrame(data, schema) - byte_f = pandas_udf(lambda x: x, ByteType()) - res = df.select(byte_f(col('byte'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_short(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("short", ShortType()) - df = self.spark.createDataFrame(data, schema) - short_f = pandas_udf(lambda x: x, ShortType()) - res = df.select(short_f(col('short'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_int(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("int", IntegerType()) - df = self.spark.createDataFrame(data, schema) - int_f = pandas_udf(lambda x: x, IntegerType()) - res = df.select(int_f(col('int'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_long(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("long", LongType()) - df = self.spark.createDataFrame(data, schema) - long_f = pandas_udf(lambda x: x, LongType()) - res = df.select(long_f(col('long'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_float(self): - from pyspark.sql.functions import pandas_udf, col - data = [(3.0,), (5.0,), (-1.0,), (None,)] - schema = StructType().add("float", FloatType()) - df = self.spark.createDataFrame(data, schema) - float_f = pandas_udf(lambda x: x, FloatType()) - res = df.select(float_f(col('float'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_double(self): - from pyspark.sql.functions import pandas_udf, col - data = [(3.0,), (5.0,), (-1.0,), (None,)] - schema = StructType().add("double", DoubleType()) - df = self.spark.createDataFrame(data, schema) - double_f = pandas_udf(lambda x: x, DoubleType()) - res = df.select(double_f(col('double'))) - self.assertEquals(df.collect(), res.collect()) 
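The `test_vectorized_udf_null_*` cases above all follow the same recipe: build a one-column DataFrame that contains `None`, wrap an identity function in `pandas_udf` with the matching return type, and check that the values (nulls included) survive the round trip through Arrow. A minimal self-contained sketch of that recipe, assuming a local session with pandas and pyarrow installed (the session and UDF names below are illustrative, not taken from this patch):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import StructType, DoubleType

    spark = SparkSession.builder.master("local[1]").appName("scalar-pandas-udf-sketch").getOrCreate()

    # One nullable double column, mirroring the data used in the tests above.
    schema = StructType().add("double", DoubleType())
    df = spark.createDataFrame([(3.0,), (5.0,), (-1.0,), (None,)], schema)

    # Scalar pandas UDF: receives a pandas.Series per batch and must return a
    # Series of the same length; here it is just the identity function.
    double_identity = pandas_udf(lambda s: s, DoubleType())

    res = df.select(double_identity(col("double")))

    # Nulls arrive in the Series as NaN and come back out as NULL, so the
    # round trip preserves the original rows.
    assert df.collect() == res.collect()

    spark.stop()

The per-type tests around this hunk differ only in the element type and the sample data.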
- - def test_vectorized_udf_null_decimal(self): - from decimal import Decimal - from pyspark.sql.functions import pandas_udf, col - data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] - schema = StructType().add("decimal", DecimalType(38, 18)) - df = self.spark.createDataFrame(data, schema) - decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18)) - res = df.select(decimal_f(col('decimal'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_string(self): - from pyspark.sql.functions import pandas_udf, col - data = [("foo",), (None,), ("bar",), ("bar",)] - schema = StructType().add("str", StringType()) - df = self.spark.createDataFrame(data, schema) - str_f = pandas_udf(lambda x: x, StringType()) - res = df.select(str_f(col('str'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_string_in_udf(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - df = self.spark.range(10) - str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) - actual = df.select(str_f(col('id'))) - expected = df.select(col('id').cast('string')) - self.assertEquals(expected.collect(), actual.collect()) - - def test_vectorized_udf_datatype_string(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10).select( - col('id').cast('string').alias('str'), - col('id').cast('int').alias('int'), - col('id').alias('long'), - col('id').cast('float').alias('float'), - col('id').cast('double').alias('double'), - col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool')) - f = lambda x: x - str_f = pandas_udf(f, 'string') - int_f = pandas_udf(f, 'integer') - long_f = pandas_udf(f, 'long') - float_f = pandas_udf(f, 'float') - double_f = pandas_udf(f, 'double') - decimal_f = pandas_udf(f, 'decimal(38, 18)') - bool_f = pandas_udf(f, 'boolean') - res = df.select(str_f(col('str')), int_f(col('int')), - long_f(col('long')), float_f(col('float')), - double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_binary(self): - from distutils.version import LooseVersion - import pyarrow as pa - from pyspark.sql.functions import pandas_udf, col - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): - pandas_udf(lambda x: x, BinaryType()) - else: - data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] - schema = StructType().add("binary", BinaryType()) - df = self.spark.createDataFrame(data, schema) - str_f = pandas_udf(lambda x: x, BinaryType()) - res = df.select(str_f(col('binary'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_array_type(self): - from pyspark.sql.functions import pandas_udf, col - data = [([1, 2],), ([3, 4],)] - array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) - df = self.spark.createDataFrame(data, schema=array_schema) - array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) - result = df.select(array_f(col('array'))) - self.assertEquals(df.collect(), result.collect()) - - def test_vectorized_udf_null_array(self): - from pyspark.sql.functions import pandas_udf, col - data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] - array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) - df = 
self.spark.createDataFrame(data, schema=array_schema) - array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) - result = df.select(array_f(col('array'))) - self.assertEquals(df.collect(), result.collect()) - - def test_vectorized_udf_complex(self): - from pyspark.sql.functions import pandas_udf, col, expr - df = self.spark.range(10).select( - col('id').cast('int').alias('a'), - col('id').cast('int').alias('b'), - col('id').cast('double').alias('c')) - add = pandas_udf(lambda x, y: x + y, IntegerType()) - power2 = pandas_udf(lambda x: 2 ** x, IntegerType()) - mul = pandas_udf(lambda x, y: x * y, DoubleType()) - res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c'))) - expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c')) - self.assertEquals(expected.collect(), res.collect()) - - def test_vectorized_udf_exception(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'): - df.select(raise_exception(col('id'))).collect() - - def test_vectorized_udf_invalid_length(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - df = self.spark.range(10) - raise_exception = pandas_udf(lambda _: pd.Series(1), LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Result vector from pandas_udf was not the required length'): - df.select(raise_exception(col('id'))).collect() - - def test_vectorized_udf_chained(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - f = pandas_udf(lambda x: x + 1, LongType()) - g = pandas_udf(lambda x: x - 1, LongType()) - res = df.select(g(f(col('id')))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf, col - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*MapType'): - pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) - - def test_vectorized_udf_return_scalar(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - f = pandas_udf(lambda x: 1.0, DoubleType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): - df.select(f(col('id'))).collect() - - def test_vectorized_udf_decorator(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - - @pandas_udf(returnType=LongType()) - def identity(x): - return x - res = df.select(identity(col('id'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_empty_partition(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) - f = pandas_udf(lambda x: x, LongType()) - res = df.select(f(col('id'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_varargs(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) - f = pandas_udf(lambda *v: v[0], LongType()) - res = df.select(f(col('id'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 
'Invalid returnType.*scalar Pandas UDF.*MapType'): - pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) - - def test_vectorized_udf_dates(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import date - schema = StructType().add("idx", LongType()).add("date", DateType()) - data = [(0, date(1969, 1, 1),), - (1, date(2012, 2, 2),), - (2, None,), - (3, date(2100, 4, 4),)] - df = self.spark.createDataFrame(data, schema=schema) - - date_copy = pandas_udf(lambda t: t, returnType=DateType()) - df = df.withColumn("date_copy", date_copy(col("date"))) - - @pandas_udf(returnType=StringType()) - def check_data(idx, date, date_copy): - import pandas as pd - msgs = [] - is_equal = date.isnull() - for i in range(len(idx)): - if (is_equal[i] and data[idx[i]][1] is None) or \ - date[i] == data[idx[i]][1]: - msgs.append(None) - else: - msgs.append( - "date values are not equal (date='%s': data[%d][1]='%s')" - % (date[i], idx[i], data[idx[i]][1])) - return pd.Series(msgs) - - result = df.withColumn("check_data", - check_data(col("idx"), col("date"), col("date_copy"))).collect() - - self.assertEquals(len(data), len(result)) - for i in range(len(result)): - self.assertEquals(data[i][1], result[i][1]) # "date" col - self.assertEquals(data[i][1], result[i][2]) # "date_copy" col - self.assertIsNone(result[i][3]) # "check_data" col - - @unittest.skip("This test flakes depending on system timezone") - def test_vectorized_udf_timestamps(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime - schema = StructType([ - StructField("idx", LongType(), True), - StructField("timestamp", TimestampType(), True)]) - data = [(0, datetime(1969, 1, 1, 1, 1, 1)), - (1, datetime(2012, 2, 2, 2, 2, 2)), - (2, None), - (3, datetime(2100, 3, 3, 3, 3, 3))] - - df = self.spark.createDataFrame(data, schema=schema) - - # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc - f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) - df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) - - @pandas_udf(returnType=StringType()) - def check_data(idx, timestamp, timestamp_copy): - import pandas as pd - msgs = [] - is_equal = timestamp.isnull() # use this array to check values are equal - for i in range(len(idx)): - # Check that timestamps are as expected in the UDF - if (is_equal[i] and data[idx[i]][1] is None) or \ - timestamp[i].to_pydatetime() == data[idx[i]][1]: - msgs.append(None) - else: - msgs.append( - "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')" - % (timestamp[i], idx[i], data[idx[i]][1])) - return pd.Series(msgs) - - result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"), - col("timestamp_copy"))).collect() - # Check that collection values are correct - self.assertEquals(len(data), len(result)) - for i in range(len(result)): - self.assertEquals(data[i][1], result[i][1]) # "timestamp" col - self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col - self.assertIsNone(result[i][3]) # "check_data" col - - def test_vectorized_udf_return_timestamp_tz(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - df = self.spark.range(10) - - @pandas_udf(returnType=TimestampType()) - def gen_timestamps(id): - ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id] - return pd.Series(ts) - - result = df.withColumn("ts", gen_timestamps(col("id"))).collect() - spark_ts_t = TimestampType() - for r in result: - i, ts = 
r - ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime() - expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) - self.assertEquals(expected, ts) - - def test_vectorized_udf_check_config(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): - df = self.spark.range(10, numPartitions=1) - - @pandas_udf(returnType=LongType()) - def check_records_per_batch(x): - return pd.Series(x.size).repeat(x.size) - - result = df.select(check_records_per_batch(col("id"))).collect() - for (r,) in result: - self.assertTrue(r <= 3) - - def test_vectorized_udf_timestamps_respect_session_timezone(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime - import pandas as pd - schema = StructType([ - StructField("idx", LongType(), True), - StructField("timestamp", TimestampType(), True)]) - data = [(1, datetime(1969, 1, 1, 1, 1, 1)), - (2, datetime(2012, 2, 2, 2, 2, 2)), - (3, None), - (4, datetime(2100, 3, 3, 3, 3, 3))] - df = self.spark.createDataFrame(data, schema=schema) - - f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType()) - internal_value = pandas_udf( - lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - - timezone = "America/New_York" - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": False, - "spark.sql.session.timeZone": timezone}): - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": True, - "spark.sql.session.timeZone": timezone}): - df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() - - self.assertNotEqual(result_ny, result_la) - self.assertEqual(result_ny, result_la_corrected) - - def test_nondeterministic_vectorized_udf(self): - # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import udf, pandas_udf, col - - @pandas_udf('double') - def plus_ten(v): - return v + 10 - random_udf = self.nondeterministic_vectorized_udf - - df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) - result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() - - self.assertEqual(random_udf.deterministic, False) - self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) - - def test_nondeterministic_vectorized_udf_in_aggregate(self): - from pyspark.sql.functions import pandas_udf, sum - - df = self.spark.range(10) - random_udf = self.nondeterministic_vectorized_udf - - with QuietTest(self.sc): - with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): - df.groupby(df.id).agg(sum(random_udf(df.id))).collect() - with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): - df.agg(sum(random_udf(df.id))).collect() - - def test_register_vectorized_udf_basic(self): - from pyspark.rdd import PythonEvalType - from 
pyspark.sql.functions import pandas_udf, col, expr - df = self.spark.range(10).select( - col('id').cast('int').alias('a'), - col('id').cast('int').alias('b')) - original_add = pandas_udf(lambda x, y: x + y, IntegerType()) - self.assertEqual(original_add.deterministic, True) - self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - new_add = self.spark.catalog.registerFunction("add1", original_add) - res1 = df.select(new_add(col('a'), col('b'))) - res2 = self.spark.sql( - "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t") - expected = df.select(expr('a + b')) - self.assertEquals(expected.collect(), res1.collect()) - self.assertEquals(expected.collect(), res2.collect()) - - # Regression test for SPARK-23314 - def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf - # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] - df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') - foo_udf = pandas_udf(lambda x: x, 'timestamp') - result = df.withColumn('time', foo_udf(df.time)) - self.assertEquals(df.collect(), result.collect()) - - @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") - def test_type_annotation(self): - from pyspark.sql.functions import pandas_udf - # Regression test to check if type hints can be used. See SPARK-23569. - # Note that it throws an error during compilation in lower Python versions if 'exec' - # is not used. Also, note that we explicitly use another dictionary to avoid modifications - # in the current 'locals()'. - # - # Hyukjin: I think it's an ugly way to test issues about syntax specific in - # higher versions of Python, which we shouldn't encourage. This was the last resort - # I could come up with at that time. - _locals = {} - exec( - "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", - _locals) - df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) - self.assertEqual(df.first()[0], 0) - - def test_mixed_udf(self): - import pandas as pd - from pyspark.sql.functions import col, udf, pandas_udf - - df = self.spark.range(0, 1).toDF('v') - - # Test mixture of multiple UDFs and Pandas UDFs. 
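        # f1 and f3 below are plain row-at-a-time Python UDFs (adding 1 and 100),
        # while f2 and f4 are vectorized pandas UDFs (adding 10 and 1000), so each
        # chained expectation further down is simply the sum of the offsets that
        # were applied, e.g. f4(f3(f2(f1(v)))) == v + 1000 + 100 + 10 + 1 == v + 1111.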
- - @udf('int') - def f1(x): - assert type(x) == int - return x + 1 - - @pandas_udf('int') - def f2(x): - assert type(x) == pd.Series - return x + 10 - - @udf('int') - def f3(x): - assert type(x) == int - return x + 100 - - @pandas_udf('int') - def f4(x): - assert type(x) == pd.Series - return x + 1000 - - # Test single expression with chained UDFs - df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v']))) - df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) - df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v']))))) - df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v'])))) - df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v'])))) - - expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11) - expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111) - expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111) - expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011) - expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101) - - self.assertEquals(expected_chained_1.collect(), df_chained_1.collect()) - self.assertEquals(expected_chained_2.collect(), df_chained_2.collect()) - self.assertEquals(expected_chained_3.collect(), df_chained_3.collect()) - self.assertEquals(expected_chained_4.collect(), df_chained_4.collect()) - self.assertEquals(expected_chained_5.collect(), df_chained_5.collect()) - - # Test multiple mixed UDF expressions in a single projection - df_multi_1 = df \ - .withColumn('f1', f1(col('v'))) \ - .withColumn('f2', f2(col('v'))) \ - .withColumn('f3', f3(col('v'))) \ - .withColumn('f4', f4(col('v'))) \ - .withColumn('f2_f1', f2(col('f1'))) \ - .withColumn('f3_f1', f3(col('f1'))) \ - .withColumn('f4_f1', f4(col('f1'))) \ - .withColumn('f3_f2', f3(col('f2'))) \ - .withColumn('f4_f2', f4(col('f2'))) \ - .withColumn('f4_f3', f4(col('f3'))) \ - .withColumn('f3_f2_f1', f3(col('f2_f1'))) \ - .withColumn('f4_f2_f1', f4(col('f2_f1'))) \ - .withColumn('f4_f3_f1', f4(col('f3_f1'))) \ - .withColumn('f4_f3_f2', f4(col('f3_f2'))) \ - .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) - - # Test mixed udfs in a single expression - df_multi_2 = df \ - .withColumn('f1', f1(col('v'))) \ - .withColumn('f2', f2(col('v'))) \ - .withColumn('f3', f3(col('v'))) \ - .withColumn('f4', f4(col('v'))) \ - .withColumn('f2_f1', f2(f1(col('v')))) \ - .withColumn('f3_f1', f3(f1(col('v')))) \ - .withColumn('f4_f1', f4(f1(col('v')))) \ - .withColumn('f3_f2', f3(f2(col('v')))) \ - .withColumn('f4_f2', f4(f2(col('v')))) \ - .withColumn('f4_f3', f4(f3(col('v')))) \ - .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ - .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ - .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ - .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ - .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) - - expected = df \ - .withColumn('f1', df['v'] + 1) \ - .withColumn('f2', df['v'] + 10) \ - .withColumn('f3', df['v'] + 100) \ - .withColumn('f4', df['v'] + 1000) \ - .withColumn('f2_f1', df['v'] + 11) \ - .withColumn('f3_f1', df['v'] + 101) \ - .withColumn('f4_f1', df['v'] + 1001) \ - .withColumn('f3_f2', df['v'] + 110) \ - .withColumn('f4_f2', df['v'] + 1010) \ - .withColumn('f4_f3', df['v'] + 1100) \ - .withColumn('f3_f2_f1', df['v'] + 111) \ - .withColumn('f4_f2_f1', df['v'] + 1011) \ - .withColumn('f4_f3_f1', df['v'] + 1101) \ - .withColumn('f4_f3_f2', df['v'] + 1110) \ - .withColumn('f4_f3_f2_f1', df['v'] + 1111) - - self.assertEquals(expected.collect(), df_multi_1.collect()) - self.assertEquals(expected.collect(), df_multi_2.collect()) - - def 
test_mixed_udf_and_sql(self): - import pandas as pd - from pyspark.sql import Column - from pyspark.sql.functions import udf, pandas_udf - - df = self.spark.range(0, 1).toDF('v') - - # Test mixture of UDFs, Pandas UDFs and SQL expression. - - @udf('int') - def f1(x): - assert type(x) == int - return x + 1 - - def f2(x): - assert type(x) == Column - return x + 10 - - @pandas_udf('int') - def f3(x): - assert type(x) == pd.Series - return x + 100 - - df1 = df.withColumn('f1', f1(df['v'])) \ - .withColumn('f2', f2(df['v'])) \ - .withColumn('f3', f3(df['v'])) \ - .withColumn('f1_f2', f1(f2(df['v']))) \ - .withColumn('f1_f3', f1(f3(df['v']))) \ - .withColumn('f2_f1', f2(f1(df['v']))) \ - .withColumn('f2_f3', f2(f3(df['v']))) \ - .withColumn('f3_f1', f3(f1(df['v']))) \ - .withColumn('f3_f2', f3(f2(df['v']))) \ - .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ - .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ - .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ - .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ - .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ - .withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) - - expected = df.withColumn('f1', df['v'] + 1) \ - .withColumn('f2', df['v'] + 10) \ - .withColumn('f3', df['v'] + 100) \ - .withColumn('f1_f2', df['v'] + 11) \ - .withColumn('f1_f3', df['v'] + 101) \ - .withColumn('f2_f1', df['v'] + 11) \ - .withColumn('f2_f3', df['v'] + 110) \ - .withColumn('f3_f1', df['v'] + 101) \ - .withColumn('f3_f2', df['v'] + 110) \ - .withColumn('f1_f2_f3', df['v'] + 111) \ - .withColumn('f1_f3_f2', df['v'] + 111) \ - .withColumn('f2_f1_f3', df['v'] + 111) \ - .withColumn('f2_f3_f1', df['v'] + 111) \ - .withColumn('f3_f1_f2', df['v'] + 111) \ - .withColumn('f3_f2_f1', df['v'] + 111) - - self.assertEquals(expected.collect(), df1.collect()) - - # SPARK-24721 - @unittest.skipIf(not _test_compiled, _test_not_compiled_message) - def test_datasource_with_udf(self): - # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF - # This needs to a separate test because Arrow dependency is optional - import pandas as pd - import numpy as np - from pyspark.sql.functions import pandas_udf, lit, col - - path = tempfile.mkdtemp() - shutil.rmtree(path) - - try: - self.spark.range(1).write.mode("overwrite").format('csv').save(path) - filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') - datasource_df = self.spark.read \ - .format("org.apache.spark.sql.sources.SimpleScanSource") \ - .option('from', 0).option('to', 1).load().toDF('i') - datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load().toDF('i', 'j') - - c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1)) - c2 = pandas_udf(lambda x: x + 1, 'int')(col('i')) - - f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) - f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i')) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c1) - expected = df.withColumn('c', lit(2)) - self.assertEquals(expected.collect(), result.collect()) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c2) - expected = df.withColumn('c', col('i') + 1) - self.assertEquals(expected.collect(), result.collect()) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - for f in [f1, f2]: - result = df.filter(f) - self.assertEquals(0, result.count()) - finally: - shutil.rmtree(path) - - -@unittest.skipIf( - not _have_pandas or not 
_have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class GroupedMapPandasUDFTests(ReusedSQLTestCase): - - @property - def data(self): - from pyspark.sql.functions import array, explode, col, lit - return self.spark.range(10).toDF('id') \ - .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ - .withColumn("v", explode(col('vs'))).drop('vs') - - def test_supported_types(self): - from decimal import Decimal - from distutils.version import LooseVersion - import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType - - values = [ - 1, 2, 3, - 4, 5, 1.1, - 2.2, Decimal(1.123), - [1, 2, 2], True, 'hello' - ] - output_fields = [ - ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()), - ('int', IntegerType()), ('long', LongType()), ('float', FloatType()), - ('double', DoubleType()), ('decim', DecimalType(10, 3)), - ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()) - ] - - # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"): - values.append(bytearray([0x01, 0x02])) - output_fields.append(('bin', BinaryType())) - - output_schema = StructType([StructField(*x) for x in output_fields]) - df = self.spark.createDataFrame([values], schema=output_schema) - - # Different forms of group map pandas UDF, results of these are the same - udf1 = pandas_udf( - lambda pdf: pdf.assign( - byte=pdf.byte * 2, - short=pdf.short * 2, - int=pdf.int * 2, - long=pdf.long * 2, - float=pdf.float * 2, - double=pdf.double * 2, - decim=pdf.decim * 2, - bool=False if pdf.bool else True, - str=pdf.str + 'there', - array=pdf.array, - ), - output_schema, - PandasUDFType.GROUPED_MAP - ) - - udf2 = pandas_udf( - lambda _, pdf: pdf.assign( - byte=pdf.byte * 2, - short=pdf.short * 2, - int=pdf.int * 2, - long=pdf.long * 2, - float=pdf.float * 2, - double=pdf.double * 2, - decim=pdf.decim * 2, - bool=False if pdf.bool else True, - str=pdf.str + 'there', - array=pdf.array, - ), - output_schema, - PandasUDFType.GROUPED_MAP - ) - - udf3 = pandas_udf( - lambda key, pdf: pdf.assign( - id=key[0], - byte=pdf.byte * 2, - short=pdf.short * 2, - int=pdf.int * 2, - long=pdf.long * 2, - float=pdf.float * 2, - double=pdf.double * 2, - decim=pdf.decim * 2, - bool=False if pdf.bool else True, - str=pdf.str + 'there', - array=pdf.array, - ), - output_schema, - PandasUDFType.GROUPED_MAP - ) - - result1 = df.groupby('id').apply(udf1).sort('id').toPandas() - expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) - - result2 = df.groupby('id').apply(udf2).sort('id').toPandas() - expected2 = expected1 - - result3 = df.groupby('id').apply(udf3).sort('id').toPandas() - expected3 = expected1 - - self.assertPandasEqual(expected1, result1) - self.assertPandasEqual(expected2, result2) - self.assertPandasEqual(expected3, result3) - - def test_array_type_correct(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col - - df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") - - output_schema = StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType()))]) - - udf = pandas_udf( - lambda pdf: pdf, - output_schema, - PandasUDFType.GROUPED_MAP - ) - - result = df.groupby('id').apply(udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - def 
test_register_grouped_map_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - ValueError, - 'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'): - self.spark.catalog.registerFunction("foo_udf", foo_udf) - - def test_decorator(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - @pandas_udf( - 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUPED_MAP - ) - def foo(pdf): - return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) - - result = df.groupby('id').apply(foo).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - def test_coerce(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo = pandas_udf( - lambda pdf: pdf, - 'id long, v double', - PandasUDFType.GROUPED_MAP - ) - - result = df.groupby('id').apply(foo).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - expected = expected.assign(v=expected.v.astype('float64')) - self.assertPandasEqual(expected, result) - - def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType - df = self.data - - @pandas_udf( - 'id long, v int, norm double', - PandasUDFType.GROUPED_MAP - ) - def normalize(pdf): - v = pdf.v - return pdf.assign(norm=(v - v.mean()) / v.std()) - - result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() - pdf = df.toPandas() - expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) - expected = expected.sort_values(['id', 'v']).reset_index(drop=True) - expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertPandasEqual(expected, result) - - def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType - df = self.data - - @pandas_udf( - 'id long, v int, norm double', - PandasUDFType.GROUPED_MAP - ) - def normalize(pdf): - v = pdf.v - return pdf.assign(norm=(v - v.mean()) / v.std()) - - result = df.groupby().apply(normalize).sort('id', 'v').toPandas() - pdf = df.toPandas() - expected = normalize.func(pdf) - expected = expected.sort_values(['id', 'v']).reset_index(drop=True) - expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertPandasEqual(expected, result) - - def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo_udf = pandas_udf( - lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUPED_MAP - ) - - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*grouped map Pandas UDF.*MapType'): - pandas_udf( - lambda pdf: pdf, - 'id long, v map<int, int>', - PandasUDFType.GROUPED_MAP) - - def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType - df = self.data - - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - 
df.groupby('id').apply(lambda x: x) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(udf(lambda x: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(sum(df.v)) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(df.v + 1) - with self.assertRaisesRegexp(ValueError, 'Invalid function'): - df.groupby('id').apply( - pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): - df.groupby('id').apply( - pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) - - def test_unsupported_types(self): - from distutils.version import LooseVersion - import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType - - common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*' - unsupported_types = [ - StructField('map', MapType(StringType(), IntegerType())), - StructField('arr_ts', ArrayType(TimestampType())), - StructField('null', NullType()), - ] - - # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - unsupported_types.append(StructField('bin', BinaryType())) - - for unsupported_type in unsupported_types: - schema = StructType([StructField('id', LongType(), True), unsupported_type]) - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, common_err_msg): - pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) - - # Regression test for SPARK-23314 - def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] - df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') - foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) - result = df.groupby('time').apply(foo_udf).sort('time') - self.assertPandasEqual(df.toPandas(), result.toPandas()) - - def test_udf_with_key(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType - df = self.data - pdf = df.toPandas() - - def foo1(key, pdf): - import numpy as np - assert type(key) == tuple - assert type(key[0]) == np.int64 - - return pdf.assign(v1=key[0], - v2=pdf.v * key[0], - v3=pdf.v * pdf.id, - v4=pdf.v * pdf.id.mean()) - - def foo2(key, pdf): - import numpy as np - assert type(key) == tuple - assert type(key[0]) == np.int64 - assert type(key[1]) == np.int32 - - return pdf.assign(v1=key[0], - v2=key[1], - v3=pdf.v * key[0], - v4=pdf.v + key[1]) - - def foo3(key, pdf): - assert type(key) == tuple - assert len(key) == 0 - return pdf.assign(v1=pdf.v * pdf.id) - - # v2 is int because numpy.int64 * pd.Series results in pd.Series - # v3 is long because pd.Series * pd.Series results in pd.Series - udf1 = pandas_udf( - foo1, - 'id long, v int, v1 long, v2 int, v3 long, v4 double', - PandasUDFType.GROUPED_MAP) - - udf2 = pandas_udf( - foo2, - 'id long, v int, v1 long, v2 int, v3 int, v4 int', - PandasUDFType.GROUPED_MAP) - - udf3 = pandas_udf( - foo3, - 'id long, v int, v1 long', - PandasUDFType.GROUPED_MAP) - - # Test groupby column - result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() - expected1 = 
pdf.groupby('id')\ - .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ - .sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected1, result1) - - # Test groupby expression - result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() - expected2 = pdf.groupby(pdf.id % 2)\ - .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ - .sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected2, result2) - - # Test complex groupby - result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() - expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ - .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ - .sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected3, result3) - - # Test empty groupby - result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() - expected4 = udf3.func((), pdf) - self.assertPandasEqual(expected4, result4) - - def test_column_order(self): - from collections import OrderedDict - import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType - - # Helper function to set column names from a list - def rename_pdf(pdf, names): - pdf.rename(columns={old: new for old, new in - zip(pd_result.columns, names)}, inplace=True) - - df = self.data - grouped_df = df.groupby('id') - grouped_pdf = df.toPandas().groupby('id') - - # Function returns a pdf with required column names, but order could be arbitrary using dict - def change_col_order(pdf): - # Constructing a DataFrame from a dict should result in the same order, - # but use from_items to ensure the pdf column order is different than schema - return pd.DataFrame.from_items([ - ('id', pdf.id), - ('u', pdf.v * 2), - ('v', pdf.v)]) - - ordered_udf = pandas_udf( - change_col_order, - 'id long, v int, u int', - PandasUDFType.GROUPED_MAP - ) - - # The UDF result should assign columns by name from the pdf - result = grouped_df.apply(ordered_udf).sort('id', 'v')\ - .select('id', 'u', 'v').toPandas() - pd_result = grouped_pdf.apply(change_col_order) - expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - # Function returns a pdf with positional columns, indexed by range - def range_col_order(pdf): - # Create a DataFrame with positional columns, fix types to long - return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') - - range_udf = pandas_udf( - range_col_order, - 'id long, u long, v long', - PandasUDFType.GROUPED_MAP - ) - - # The UDF result uses positional columns from the pdf - result = grouped_df.apply(range_udf).sort('id', 'v') \ - .select('id', 'u', 'v').toPandas() - pd_result = grouped_pdf.apply(range_col_order) - rename_pdf(pd_result, ['id', 'u', 'v']) - expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - # Function returns a pdf with columns indexed with integers - def int_index(pdf): - return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) - - int_index_udf = pandas_udf( - int_index, - 'id long, u int, v int', - PandasUDFType.GROUPED_MAP - ) - - # The UDF result should assign columns by position of integer index - result = grouped_df.apply(int_index_udf).sort('id', 'v') \ - .select('id', 'u', 'v').toPandas() - pd_result = grouped_pdf.apply(int_index) - rename_pdf(pd_result, ['id', 'u', 'v']) - expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - @pandas_udf('id long, v int', 
PandasUDFType.GROUPED_MAP) - def column_name_typo(pdf): - return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) - - @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) - def invalid_positional_types(pdf): - return pd.DataFrame([(u'a', 1.2)]) - - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): - grouped_df.apply(column_name_typo).collect() - with self.assertRaisesRegexp(Exception, "No cast implemented"): - grouped_df.apply(invalid_positional_types).collect() - - def test_positional_assignment_conf(self): - import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with self.sql_conf({ - "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}): - - @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) - def foo(_): - return pd.DataFrame([('hi', 1)], columns=['x', 'y']) - - df = self.data - result = df.groupBy('id').apply(foo).select('a', 'b').collect() - for r in result: - self.assertEqual(r.a, 'hi') - self.assertEqual(r.b, 1) - - def test_self_join_with_pandas(self): - import pyspark.sql.functions as F - - @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) - def dummy_pandas_udf(df): - return df[['key', 'col']] - - df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), - Row(key=2, col='C')]) - df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) - - # this was throwing an AnalysisException before SPARK-24208 - res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), - F.col('temp0.key') == F.col('temp1.key')) - self.assertEquals(res.count(), 5) - - def test_mixed_scalar_udfs_followed_by_grouby_apply(self): - import pandas as pd - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType - - df = self.spark.range(0, 10).toDF('v1') - df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ - .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) - - result = df.groupby() \ - .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), - 'sum int', - PandasUDFType.GROUPED_MAP)) - - self.assertEquals(result.collect()[0]['sum'], 165) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class GroupedAggPandasUDFTests(ReusedSQLTestCase): - - @property - def data(self): - from pyspark.sql.functions import array, explode, col, lit - return self.spark.range(10).toDF('id') \ - .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ - .withColumn("v", explode(col('vs'))) \ - .drop('vs') \ - .withColumn('w', lit(1.0)) - - @property - def python_plus_one(self): - from pyspark.sql.functions import udf - - @udf('double') - def plus_one(v): - assert isinstance(v, (int, float)) - return v + 1 - return plus_one - - @property - def pandas_scalar_plus_two(self): - import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.SCALAR) - def plus_two(v): - assert isinstance(v, pd.Series) - return v + 2 - return plus_two - - @property - def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def avg(v): - return v.mean() - return avg - - @property - def pandas_agg_sum_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def sum(v): - return v.sum() - return sum - - @property - def pandas_agg_weighted_mean_udf(self): - import 
numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def weighted_mean(v, w): - return np.average(v, weights=w) - return weighted_mean - - def test_manual(self): - from pyspark.sql.functions import pandas_udf, array - - df = self.data - sum_udf = self.pandas_agg_sum_udf - mean_udf = self.pandas_agg_mean_udf - mean_arr_udf = pandas_udf( - self.pandas_agg_mean_udf.func, - ArrayType(self.pandas_agg_mean_udf.returnType), - self.pandas_agg_mean_udf.evalType) - - result1 = df.groupby('id').agg( - sum_udf(df.v), - mean_udf(df.v), - mean_arr_udf(array(df.v))).sort('id') - expected1 = self.spark.createDataFrame( - [[0, 245.0, 24.5, [24.5]], - [1, 255.0, 25.5, [25.5]], - [2, 265.0, 26.5, [26.5]], - [3, 275.0, 27.5, [27.5]], - [4, 285.0, 28.5, [28.5]], - [5, 295.0, 29.5, [29.5]], - [6, 305.0, 30.5, [30.5]], - [7, 315.0, 31.5, [31.5]], - [8, 325.0, 32.5, [32.5]], - [9, 335.0, 33.5, [33.5]]], - ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_basic(self): - from pyspark.sql.functions import col, lit, sum, mean - - df = self.data - weighted_mean_udf = self.pandas_agg_weighted_mean_udf - - # Groupby one column and aggregate one UDF with literal - result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') - expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - # Groupby one expression and aggregate one UDF with literal - result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ - .sort(df.id + 1) - expected2 = df.groupby((col('id') + 1))\ - .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - - # Groupby one column and aggregate one UDF without literal - result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') - expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - - # Groupby one expression and aggregate one UDF without literal - result4 = df.groupby((col('id') + 1).alias('id'))\ - .agg(weighted_mean_udf(df.v, df.w))\ - .sort('id') - expected4 = df.groupby((col('id') + 1).alias('id'))\ - .agg(mean(df.v).alias('weighted_mean(v, w)'))\ - .sort('id') - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - - def test_unsupported_types(self): - from pyspark.sql.types import DoubleType, MapType - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - pandas_udf( - lambda x: x, - ArrayType(ArrayType(TimestampType())), - PandasUDFType.GROUPED_AGG) - - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return v.mean(), v.std() - - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return {v.mean(): v.std()} - - def test_alias(self): - from pyspark.sql.functions import mean - - df = self.data - mean_udf = self.pandas_agg_mean_udf - - result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) - expected1 = 
df.groupby('id').agg(mean(df.v).alias('mean_alias')) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_mixed_sql(self): - """ - Test mixing group aggregate pandas UDF with sql expression. - """ - from pyspark.sql.functions import sum, mean - - df = self.data - sum_udf = self.pandas_agg_sum_udf - - # Mix group aggregate pandas UDF with sql expression - result1 = (df.groupby('id') - .agg(sum_udf(df.v) + 1) - .sort('id')) - expected1 = (df.groupby('id') - .agg(sum(df.v) + 1) - .sort('id')) - - # Mix group aggregate pandas UDF with sql expression (order swapped) - result2 = (df.groupby('id') - .agg(sum_udf(df.v + 1)) - .sort('id')) - - expected2 = (df.groupby('id') - .agg(sum(df.v + 1)) - .sort('id')) - - # Wrap group aggregate pandas UDF with two sql expressions - result3 = (df.groupby('id') - .agg(sum_udf(df.v + 1) + 2) - .sort('id')) - expected3 = (df.groupby('id') - .agg(sum(df.v + 1) + 2) - .sort('id')) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - - def test_mixed_udfs(self): - """ - Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. - """ - from pyspark.sql.functions import sum, mean - - df = self.data - plus_one = self.python_plus_one - plus_two = self.pandas_scalar_plus_two - sum_udf = self.pandas_agg_sum_udf - - # Mix group aggregate pandas UDF and python UDF - result1 = (df.groupby('id') - .agg(plus_one(sum_udf(df.v))) - .sort('id')) - expected1 = (df.groupby('id') - .agg(plus_one(sum(df.v))) - .sort('id')) - - # Mix group aggregate pandas UDF and python UDF (order swapped) - result2 = (df.groupby('id') - .agg(sum_udf(plus_one(df.v))) - .sort('id')) - expected2 = (df.groupby('id') - .agg(sum(plus_one(df.v))) - .sort('id')) - - # Mix group aggregate pandas UDF and scalar pandas UDF - result3 = (df.groupby('id') - .agg(sum_udf(plus_two(df.v))) - .sort('id')) - expected3 = (df.groupby('id') - .agg(sum(plus_two(df.v))) - .sort('id')) - - # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) - result4 = (df.groupby('id') - .agg(plus_two(sum_udf(df.v))) - .sort('id')) - expected4 = (df.groupby('id') - .agg(plus_two(sum(df.v))) - .sort('id')) - - # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby - result5 = (df.groupby(plus_one(df.id)) - .agg(plus_one(sum_udf(plus_one(df.v)))) - .sort('plus_one(id)')) - expected5 = (df.groupby(plus_one(df.id)) - .agg(plus_one(sum(plus_one(df.v)))) - .sort('plus_one(id)')) - - # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in - # groupby - result6 = (df.groupby(plus_two(df.id)) - .agg(plus_two(sum_udf(plus_two(df.v)))) - .sort('plus_two(id)')) - expected6 = (df.groupby(plus_two(df.id)) - .agg(plus_two(sum(plus_two(df.v)))) - .sort('plus_two(id)')) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) - self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) - - def test_multiple_udfs(self): - """ - Test multiple group aggregate pandas UDFs in one agg function. 
- """ - from pyspark.sql.functions import col, lit, sum, mean - - df = self.data - mean_udf = self.pandas_agg_mean_udf - sum_udf = self.pandas_agg_sum_udf - weighted_mean_udf = self.pandas_agg_weighted_mean_udf - - result1 = (df.groupBy('id') - .agg(mean_udf(df.v), - sum_udf(df.v), - weighted_mean_udf(df.v, df.w)) - .sort('id') - .toPandas()) - expected1 = (df.groupBy('id') - .agg(mean(df.v), - sum(df.v), - mean(df.v).alias('weighted_mean(v, w)')) - .sort('id') - .toPandas()) - - self.assertPandasEqual(expected1, result1) - - def test_complex_groupby(self): - from pyspark.sql.functions import lit, sum - - df = self.data - sum_udf = self.pandas_agg_sum_udf - plus_one = self.python_plus_one - plus_two = self.pandas_scalar_plus_two - - # groupby one expression - result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) - expected1 = df.groupby(df.v % 2).agg(sum(df.v)) - - # empty groupby - result2 = df.groupby().agg(sum_udf(df.v)) - expected2 = df.groupby().agg(sum(df.v)) - - # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) - - # groupby one python UDF - result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) - expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) - - # groupby one scalar pandas UDF - result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) - expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) - - # groupby one expression and one python UDF - result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) - expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) - - # groupby one expression and one scalar pandas UDF - result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') - expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) - self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) - self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) - - def test_complex_expressions(self): - from pyspark.sql.functions import col, sum - - df = self.data - plus_one = self.python_plus_one - plus_two = self.pandas_scalar_plus_two - sum_udf = self.pandas_agg_sum_udf - - # Test complex expressions with sql expression, python UDF and - # group aggregate pandas UDF - result1 = (df.withColumn('v1', plus_one(df.v)) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum_udf(col('v')), - sum_udf(col('v1') + 3), - sum_udf(col('v2')) + 5, - plus_one(sum_udf(col('v1'))), - sum_udf(plus_one(col('v2')))) - .sort('id') - .toPandas()) - - expected1 = (df.withColumn('v1', df.v + 1) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum(col('v')), - sum(col('v1') + 3), - sum(col('v2')) + 5, - plus_one(sum(col('v1'))), - sum(plus_one(col('v2')))) - .sort('id') - .toPandas()) - - # Test complex expressions with sql expression, scala pandas UDF and - # group aggregate pandas UDF - result2 = (df.withColumn('v1', plus_one(df.v)) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum_udf(col('v')), - sum_udf(col('v1') + 3), - sum_udf(col('v2')) + 5, - plus_two(sum_udf(col('v1'))), - 
sum_udf(plus_two(col('v2')))) - .sort('id') - .toPandas()) - - expected2 = (df.withColumn('v1', df.v + 1) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum(col('v')), - sum(col('v1') + 3), - sum(col('v2')) + 5, - plus_two(sum(col('v1'))), - sum(plus_two(col('v2')))) - .sort('id') - .toPandas()) - - # Test sequential groupby aggregate - result3 = (df.groupby('id') - .agg(sum_udf(df.v).alias('v')) - .groupby('id') - .agg(sum_udf(col('v'))) - .sort('id') - .toPandas()) - - expected3 = (df.groupby('id') - .agg(sum(df.v).alias('v')) - .groupby('id') - .agg(sum(col('v'))) - .sort('id') - .toPandas()) - - self.assertPandasEqual(expected1, result1) - self.assertPandasEqual(expected2, result2) - self.assertPandasEqual(expected3, result3) - - def test_retain_group_columns(self): - from pyspark.sql.functions import sum, lit, col - with self.sql_conf({"spark.sql.retainGroupColumns": False}): - df = self.data - sum_udf = self.pandas_agg_sum_udf - - result1 = df.groupby(df.id).agg(sum_udf(df.v)) - expected1 = df.groupby(df.id).agg(sum(df.v)) - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - df = self.data - - array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG) - result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) - self.assertEquals(result1.first()['v2'], [1.0, 2.0]) - - def test_invalid_args(self): - from pyspark.sql.functions import mean - - df = self.data - plus_one = self.python_plus_one - mean_udf = self.pandas_agg_mean_udf - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - 'nor.*aggregate function'): - df.groupby(df.id).agg(plus_one(df.v)).collect() - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - 'aggregate function.*argument.*aggregate function'): - df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - 'mixture.*aggregate function.*group aggregate pandas UDF'): - df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() - - def test_register_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - - sum_pandas_udf = pandas_udf( - lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) - - self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) - group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf) - self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) - q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" - actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect())) - expected = [1, 5] - self.assertEqual(actual, expected) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class WindowPandasUDFTests(ReusedSQLTestCase): - @property - def data(self): - from pyspark.sql.functions import array, explode, col, lit - return self.spark.range(10).toDF('id') \ - .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ - .withColumn("v", explode(col('vs'))) \ - .drop('vs') \ - .withColumn('w', lit(1.0)) - - @property - def python_plus_one(self): - from pyspark.sql.functions import udf - return udf(lambda v: v + 1, 'double') - - @property - def pandas_scalar_time_two(self): 
- from pyspark.sql.functions import pandas_udf, PandasUDFType - return pandas_udf(lambda v: v * 2, 'double') - - @property - def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def avg(v): - return v.mean() - return avg - - @property - def pandas_agg_max_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def max(v): - return v.max() - return max - - @property - def pandas_agg_min_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def min(v): - return v.min() - return min - - @property - def unbounded_window(self): - return Window.partitionBy('id') \ - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - - @property - def ordered_window(self): - return Window.partitionBy('id').orderBy('v') - - @property - def unpartitioned_window(self): - return Window.partitionBy() - - def test_simple(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max - - df = self.data - w = self.unbounded_window - - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) - expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) - - result2 = df.select(mean_udf(df['v']).over(w)) - expected2 = df.select(mean(df['v']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - - def test_multiple_udfs(self): - from pyspark.sql.functions import max, min, mean - - df = self.data - w = self.unbounded_window - - result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ - .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ - .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) - - expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ - .withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('min_w', min(df['w']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_replace_existing(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unbounded_window - - result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) - expected1 = df.withColumn('v', mean(df['v']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_mixed_sql(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unbounded_window - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) - expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_mixed_udf(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unbounded_window - - plus_one = self.python_plus_one - time_two = self.pandas_scalar_time_two - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn( - 'v2', - plus_one(mean_udf(plus_one(df['v'])).over(w))) - expected1 = df.withColumn( - 'v2', - plus_one(mean(plus_one(df['v'])).over(w))) - - result2 = df.withColumn( - 'v2', - time_two(mean_udf(time_two(df['v'])).over(w))) - expected2 = df.withColumn( - 'v2', - time_two(mean(time_two(df['v'])).over(w))) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), 
result2.toPandas()) - - def test_without_partitionBy(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unpartitioned_window - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) - expected1 = df.withColumn('v2', mean(df['v']).over(w)) - - result2 = df.select(mean_udf(df['v']).over(w)) - expected2 = df.select(mean(df['v']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - - def test_mixed_sql_and_udf(self): - from pyspark.sql.functions import max, min, rank, col - - df = self.data - w = self.unbounded_window - ow = self.ordered_window - max_udf = self.pandas_agg_max_udf - min_udf = self.pandas_agg_min_udf - - result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) - expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) - - # Test mixing sql window function and window udf in the same expression - result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) - expected2 = expected1 - - # Test chaining sql aggregate function and udf - result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('min_v', min(df['v']).over(w)) \ - .withColumn('v_diff', col('max_v') - col('min_v')) \ - .drop('max_v', 'min_v') - expected3 = expected1 - - # Test mixing sql window function and udf - result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) - expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - - def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - df = self.data - w = self.unbounded_window - - array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG) - result1 = df.withColumn('v2', array_udf(df['v']).over(w)) - self.assertEquals(result1.first()['v2'], [1.0, 2.0]) - - def test_invalid_args(self): - from pyspark.sql.functions import mean, pandas_udf, PandasUDFType - - df = self.data - w = self.unbounded_window - ow = self.ordered_window - mean_udf = self.pandas_agg_mean_udf - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - '.*not supported within a window function'): - foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) - df.withColumn('v2', foo_udf(df['v']).over(w)) - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - '.*Only unbounded window frame is supported.*'): - df.withColumn('mean_v', mean_udf(df['v']).over(ow)) - - -if __name__ == "__main__": - from pyspark.sql.tests import * - - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.sql_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 6e75e82d58009..751c2eddf5cb6 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -46,7 +46,7 @@ def setUpClass(cls): # Synchronize default timezone between Python and Java cls.tz_prev = 
os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" + tz = "UTC" os.environ["TZ"] = tz time.tzset() @@ -232,7 +232,7 @@ def test_createDataFrame_respect_session_timezone(self): self.assertNotEqual(result_ny, result_la) # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v + result_la_corrected = [Row(**{k: v + timedelta(hours=5) if k == '8_timestamp_t' else v for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 2f585a3725988..b303398850394 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -42,7 +42,7 @@ def setUpClass(cls): # Synchronize default timezone between Python and Java cls.tz_prev = os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" + tz = "UTC" os.environ["TZ"] = tz time.tzset() @@ -503,9 +503,9 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): .withColumn("internal_value", internal_value(col("timestamp"))) result_la = df_la.select(col("idx"), col("internal_value")).collect() # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + diff = 5 * 60 * 60 * 1000 * 1000 * 1000 result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + df_la.select(col("idx"), col("tscopy"), col("internal_value") - diff).collect() with self.sql_conf({ "spark.sql.execution.pandas.respectSessionTimeZone": True, diff --git a/python/pyspark/streaming/tests/test_dstream.py b/python/pyspark/streaming/tests/test_dstream.py index 09a35552563b2..d14e346b7a688 100644 --- a/python/pyspark/streaming/tests/test_dstream.py +++ b/python/pyspark/streaming/tests/test_dstream.py @@ -17,153 +17,15 @@ import operator import os import shutil -<<<<<<< HEAD:python/pyspark/streaming/tests.py -import unishark -======= import tempfile import time import unittest ->>>>>>> 87bd9c75df:python/pyspark/streaming/tests/test_dstream.py from functools import reduce from itertools import chain -<<<<<<< HEAD:python/pyspark/streaming/tests.py -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -if sys.version >= "3": - long = int - -from pyspark.context import SparkConf, SparkContext, RDD -from pyspark.storagelevel import StorageLevel -from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream -from pyspark.streaming.listener import StreamingListener - - -class PySparkStreamingTestCase(unittest.TestCase): - - timeout = 30 # seconds - duration = .5 - - @classmethod - def setUpClass(cls): - class_name = cls.__name__ - conf = SparkConf().set("spark.default.parallelism", 1) - cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir(tempfile.mkdtemp()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - # Clean up in the JVM just in case there has been some issues in Python API - try: - jSparkContextOption = SparkContext._jvm.SparkContext.get() - if jSparkContextOption.nonEmpty(): - jSparkContextOption.get().stop() - 
except: - pass - - def setUp(self): - self.ssc = StreamingContext(self.sc, self.duration) - - def tearDown(self): - if self.ssc is not None: - self.ssc.stop(False) - # Clean up in the JVM just in case there has been some issues in Python API - try: - jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() - if jStreamingContextOption.nonEmpty(): - jStreamingContextOption.get().stop(False) - except: - pass - - def wait_for(self, result, n): - start_time = time.time() - while len(result) < n and time.time() - start_time < self.timeout: - time.sleep(0.01) - if len(result) < n: - print("timeout after", self.timeout) - - def _take(self, dstream, n): - """ - Return the first `n` elements in the stream (will start and stop). - """ - results = [] - - def take(_, rdd): - if rdd and len(results) < n: - results.extend(rdd.take(n - len(results))) - - dstream.foreachRDD(take) - - self.ssc.start() - self.wait_for(results, n) - return results - - def _collect(self, dstream, n, block=True): - """ - Collect each RDDs into the returned list. - - :return: list, which will have the collected items. - """ - result = [] - - def get_output(_, rdd): - if rdd and len(result) < n: - r = rdd.collect() - if r: - result.append(r) - - dstream.foreachRDD(get_output) - - if not block: - return result - - self.ssc.start() - self.wait_for(result, n) - return result - - def _test_func(self, input, func, expected, sort=False, input2=None): - """ - @param input: dataset for the test. This should be list of lists. - @param func: wrapped function. This function should return PythonDStream object. - @param expected: expected output for this testcase. - """ - if not isinstance(input[0], RDD): - input = [self.sc.parallelize(d, 1) for d in input] - input_stream = self.ssc.queueStream(input) - if input2 and not isinstance(input2[0], RDD): - input2 = [self.sc.parallelize(d, 1) for d in input2] - input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None - - # Apply test function to stream. 
- if input2: - stream = func(input_stream, input_stream2) - else: - stream = func(input_stream) - - result = self._collect(stream, len(expected)) - if sort: - self._sort_result_based_on_key(result) - self._sort_result_based_on_key(expected) - self.assertEqual(expected, result) - - def _sort_result_based_on_key(self, outputs): - """Sort the list based on first value.""" - for output in outputs: - output.sort(key=lambda x: x[0]) -======= from pyspark import SparkConf, SparkContext, RDD from pyspark.streaming import StreamingContext from pyspark.testing.streamingutils import PySparkStreamingTestCase ->>>>>>> 87bd9c75df:python/pyspark/streaming/tests/test_dstream.py class BasicOperationTests(PySparkStreamingTestCase): @@ -769,53 +631,6 @@ def check_output(n): if __name__ == "__main__": -<<<<<<< HEAD:python/pyspark/streaming/tests.py - from pyspark.streaming.tests import * - kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - - if kinesis_asl_assembly_jar is None: - kinesis_jar_present = False - jars_args = "" - else: - kinesis_jar_present = True - jars_args = "--jars %s" % kinesis_asl_assembly_jar - - existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") - os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) - testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - StreamingListenerTests] - - if kinesis_jar_present is True: - testcases.append(KinesisStreamTests) - elif are_kinesis_tests_enabled is False: - sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " - "not compiled into a JAR. To run these tests, " - "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " - "streaming-kinesis-asl-assembly/assembly' or " - "'build/mvn -Pkinesis-asl package' before running this test.") - else: - raise Exception( - ("Failed to find Spark Streaming Kinesis assembly jar in %s. " - % _kinesis_asl_assembly_dir()) + - "You need to build Spark with 'build/sbt -Pkinesis-asl " - "assembly/package streaming-kinesis-asl-assembly/assembly'" - "or 'build/mvn -Pkinesis-asl package' before running this test.") - - sys.stderr.write("Running tests: %s \n" % (str(testcases))) - failed = False - for testcase in testcases: - sys.stderr.write("[Running %s]\n" % (testcase)) - tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - runner = unishark.BufferedTestRunner( - verbosity=2, - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.streaming_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - - result = runner.run(tests) - if not result.wasSuccessful(): - failed = True - sys.exit(failed) -======= from pyspark.streaming.tests.test_dstream import * try: @@ -823,4 +638,3 @@ def check_output(n): unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) except ImportError: unittest.main(verbosity=2) ->>>>>>> 87bd9c75df:python/pyspark/streaming/tests/test_dstream.py diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py deleted file mode 100644 index c15d443ebbba9..0000000000000 --- a/python/pyspark/tests.py +++ /dev/null @@ -1,2522 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for PySpark; additional tests are implemented as doctests in -individual modules. -""" - -from array import array -from glob import glob -import os -import re -import shutil -import subprocess -import sys -import tempfile -import time -import zipfile -import random -import threading -import hashlib - -from py4j.protocol import Py4JJavaError -xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - if sys.version_info[0] >= 3: - xrange = range - basestring = str - -import unishark - -if sys.version >= "3": - from io import StringIO -else: - from StringIO import StringIO - - -from pyspark import keyword_only -from pyspark.conf import SparkConf -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.files import SparkFiles -from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ - PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ - FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter -from pyspark import shuffle -from pyspark.profiler import BasicProfiler -from pyspark.taskcontext import BarrierTaskContext, TaskContext - -_have_scipy = False -_have_numpy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass -try: - import numpy as np - _have_numpy = True -except: - # No NumPy, but that's okay, we'll skip those tests - pass - - -SPARK_HOME = os.environ["SPARK_HOME"] - - -class MergerTests(unittest.TestCase): - - def setUp(self): - self.N = 1 << 12 - self.l = [i for i in xrange(self.N)] - self.data = list(zip(self.l, self.l)) - self.agg = Aggregator(lambda x: [x], - lambda x, y: x.append(y) or x, - lambda x, y: x.extend(y) or x) - - def test_small_dataset(self): - m = ExternalMerger(self.agg, 1000) - m.mergeValues(self.data) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - def test_medium_dataset(self): - m = ExternalMerger(self.agg, 20) - m.mergeValues(self.data) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N)) * 3) - - def test_huge_dataset(self): - m = 
ExternalMerger(self.agg, 5, partitions=3) - m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m.items()), - self.N * 10) - m._cleanup() - - def test_group_by_key(self): - - def gen_data(N, step): - for i in range(1, N + 1, step): - for j in range(i): - yield (i, [j]) - - def gen_gs(N, step=1): - return shuffle.GroupByKey(gen_data(N, step)) - - self.assertEqual(1, len(list(gen_gs(1)))) - self.assertEqual(2, len(list(gen_gs(2)))) - self.assertEqual(100, len(list(gen_gs(100)))) - self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) - self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) - - for k, vs in gen_gs(50002, 10000): - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - ser = PickleSerializer() - l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) - for k, vs in l: - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - def test_stopiteration_is_raised(self): - - def stopit(*args, **kwargs): - raise StopIteration() - - def legit_create_combiner(x): - return [x] - - def legit_merge_value(x, y): - return x.append(y) or x - - def legit_merge_combiners(x, y): - return x.extend(y) or x - - data = [(x % 2, x) for x in range(100)] - - # wrong create combiner - m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge value - m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge combiners - m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) - - -class SorterTests(unittest.TestCase): - def test_in_memory_sort(self): - l = list(range(1024)) - random.shuffle(l) - sorter = ExternalSorter(1024) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - - def test_external_sort(self): - class CustomizedSorter(ExternalSorter): - def _next_limit(self): - return self.memory_limit - l = list(range(1024)) - random.shuffle(l) - sorter = CustomizedSorter(1) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertGreater(shuffle.DiskBytesSpilled, 0) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - - def test_external_sort_in_rdd(self): - conf = SparkConf().set("spark.python.worker.memory", "1m") - sc = SparkContext(conf=conf) - l = list(range(10240)) - random.shuffle(l) - rdd = sc.parallelize(l, 4) - 
self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class SerializationTestCase(unittest.TestCase): - - def test_namedtuple(self): - from collections import namedtuple - from pickle import dumps, loads - P = namedtuple("P", "x y") - p1 = P(1, 3) - p2 = loads(dumps(p1, 2)) - self.assertEqual(p1, p2) - - from pyspark.cloudpickle import dumps - P2 = loads(dumps(P)) - p3 = P2(1, 3) - self.assertEqual(p1, p3) - - def test_itemgetter(self): - from operator import itemgetter - ser = CloudPickleSerializer() - d = range(10) - getter = itemgetter(1) - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - getter = itemgetter(0, 3) - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - def test_function_module_name(self): - ser = CloudPickleSerializer() - func = lambda x: x - func2 = ser.loads(ser.dumps(func)) - self.assertEqual(func.__module__, func2.__module__) - - def test_attrgetter(self): - from operator import attrgetter - ser = CloudPickleSerializer() - - class C(object): - def __getattr__(self, item): - return item - d = C() - getter = attrgetter("a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("a", "b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - d.e = C() - getter = attrgetter("e.a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("e.a", "e.b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - # Regression test for SPARK-3415 - def test_pickling_file_handles(self): - # to be corrected with SPARK-11160 - if not xmlrunner: - ser = CloudPickleSerializer() - out1 = sys.stderr - out2 = ser.loads(ser.dumps(out1)) - self.assertEqual(out1, out2) - - def test_func_globals(self): - - class Unpicklable(object): - def __reduce__(self): - raise Exception("not picklable") - - global exit - exit = Unpicklable() - - ser = CloudPickleSerializer() - self.assertRaises(Exception, lambda: ser.dumps(exit)) - - def foo(): - sys.exit(0) - - self.assertTrue("exit" in foo.__code__.co_names) - ser.dumps(foo) - - def test_compressed_serializer(self): - ser = CompressedSerializer(PickleSerializer()) - try: - from StringIO import StringIO - except ImportError: - from io import BytesIO as StringIO - io = StringIO() - ser.dump_stream(["abc", u"123", range(5)], io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) - ser.dump_stream(range(1000), io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) - io.close() - - def test_hash_serializer(self): - hash(NoOpSerializer()) - hash(UTF8Deserializer()) - hash(PickleSerializer()) - hash(MarshalSerializer()) - hash(AutoSerializer()) - hash(BatchedSerializer(PickleSerializer())) - hash(AutoBatchedSerializer(MarshalSerializer())) - hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CompressedSerializer(PickleSerializer())) - hash(FlattenedValuesSerializer(PickleSerializer())) - - -class QuietTest(object): - def __init__(self, sc): - self.log4j = sc._jvm.org.apache.log4j - - def __enter__(self): - self.old_level = self.log4j.LogManager.getRootLogger().getLevel() - self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) - - def __exit__(self, exc_type, exc_val, exc_tb): - 
self.log4j.LogManager.getRootLogger().setLevel(self.old_level) - - -class PySparkTestCase(unittest.TestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name) - - def tearDown(self): - self.sc.stop() - sys.path = self._old_sys_path - - -class ReusedPySparkTestCase(unittest.TestCase): - - @classmethod - def conf(cls): - """ - Override this in subclasses to supply a more specific conf - """ - return SparkConf() - - @classmethod - def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - - -class CheckpointTests(ReusedPySparkTestCase): - - def setUp(self): - self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.checkpointDir.name) - self.sc.setCheckpointDir(self.checkpointDir.name) - - def tearDown(self): - shutil.rmtree(self.checkpointDir.name) - - def test_basic_checkpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - self.assertEqual("file:" + self.checkpointDir.name, - os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) - - def test_checkpoint_and_restore(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: [x]) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - flatMappedRDD.count() # forces a checkpoint to be computed - time.sleep(1) # 1 second - - self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), - flatMappedRDD._jrdd_deserializer) - self.assertEqual([1, 2, 3, 4], recovered.collect()) - - -class LocalCheckpointTests(ReusedPySparkTestCase): - - def test_basic_localcheckpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) - - flatMappedRDD.localCheckpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - - -class AddFileTests(PySparkTestCase): - - def test_add_py_file(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this job fails due to `userlibrary` not being on the Python path: - # disable logging in log4j temporarily - def func(x): - from userlibrary import UserClass - return UserClass().hello() - with QuietTest(self.sc): - self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) - - # Add the file, so the job should now succeed: - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - res = self.sc.parallelize(range(2)).map(func).first() - self.assertEqual("Hello World!", res) - - def test_add_file_locally(self): - path = os.path.join(SPARK_HOME, 
"python/test_support/hello/hello.txt") - self.sc.addFile(path) - download_path = SparkFiles.get("hello.txt") - self.assertNotEqual(path, download_path) - with open(download_path) as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - - def test_add_file_recursively_locally(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello") - self.sc.addFile(path, True) - download_path = SparkFiles.get("hello") - self.assertNotEqual(path, download_path) - with open(download_path + "/hello.txt") as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - with open(download_path + "/sub_hello/sub_hello.txt") as test_file: - self.assertEqual("Sub Hello World!\n", test_file.readline()) - - def test_add_py_file_locally(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlibrary import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - from userlibrary import UserClass - self.assertEqual("Hello World!", UserClass().hello()) - - def test_add_egg_file_locally(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlib import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") - self.sc.addPyFile(path) - from userlib import UserClass - self.assertEqual("Hello World from inside a package!", UserClass().hello()) - - def test_overwrite_system_module(self): - self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) - - import SimpleHTTPServer - self.assertEqual("My Server", SimpleHTTPServer.__name__) - - def func(x): - import SimpleHTTPServer - return SimpleHTTPServer.__name__ - - self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) - - -class TaskContextTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - # Allow retries even though they are normally disabled in local mode - self.sc = SparkContext('local[4, 2]', class_name) - - def test_stage_id(self): - """Test the stage ids are available and incrementing as expected.""" - rdd = self.sc.parallelize(range(10)) - stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - # Test using the constructor directly rather than the get() - stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] - self.assertEqual(stage1 + 1, stage2) - self.assertEqual(stage1 + 2, stage3) - self.assertEqual(stage2 + 1, stage3) - - def test_partition_id(self): - """Test the partition id.""" - rdd1 = self.sc.parallelize(range(10), 1) - rdd2 = self.sc.parallelize(range(10), 2) - pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() - pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() - self.assertEqual(0, pids1[0]) - self.assertEqual(0, pids1[9]) - self.assertEqual(0, pids2[0]) - self.assertEqual(1, pids2[9]) - - def test_attempt_number(self): - """Verify the attempt numbers are correctly reported.""" - rdd = self.sc.parallelize(range(10)) - # Verify a simple job with no failures - attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() - map(lambda attempt: self.assertEqual(0, 
attempt), attempt_numbers) - - def fail_on_first(x): - """Fail on the first attempt so we get a positive attempt number""" - tc = TaskContext.get() - attempt_number = tc.attemptNumber() - partition_id = tc.partitionId() - attempt_id = tc.taskAttemptId() - if attempt_number == 0 and partition_id == 0: - raise Exception("Failing on first attempt") - else: - return [x, partition_id, attempt_number, attempt_id] - result = rdd.map(fail_on_first).collect() - # We should re-submit the first partition to it but other partitions should be attempt 0 - self.assertEqual([0, 0, 1], result[0][0:3]) - self.assertEqual([9, 3, 0], result[9][0:3]) - first_partition = filter(lambda x: x[1] == 0, result) - map(lambda x: self.assertEqual(1, x[2]), first_partition) - other_partitions = filter(lambda x: x[1] != 0, result) - map(lambda x: self.assertEqual(0, x[2]), other_partitions) - # The task attempt id should be different - self.assertTrue(result[0][3] != result[9][3]) - - def test_tc_on_driver(self): - """Verify that getting the TaskContext on the driver returns None.""" - tc = TaskContext.get() - self.assertTrue(tc is None) - - def test_get_local_property(self): - """Verify that local properties set on the driver are available in TaskContext.""" - key = "testkey" - value = "testvalue" - self.sc.setLocalProperty(key, value) - try: - rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] - self.assertEqual(prop1, value) - prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] - self.assertTrue(prop2 is None) - finally: - self.sc.setLocalProperty(key, None) - - def test_barrier(self): - """ - Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks - within a stage. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - def context_barrier(x): - tc = BarrierTaskContext.get() - time.sleep(random.randint(1, 10)) - tc.barrier() - return time.time() - - times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() - self.assertTrue(max(times) - min(times) < 1) - - def test_barrier_infos(self): - """ - Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the - barrier stage. 
- """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() - .getTaskInfos()).collect() - self.assertTrue(len(taskInfos) == 4) - self.assertTrue(len(taskInfos[0]) == 4) - - -class RDDTests(ReusedPySparkTestCase): - - def test_range(self): - self.assertEqual(self.sc.range(1, 1).count(), 0) - self.assertEqual(self.sc.range(1, 0, -1).count(), 1) - self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) - - def test_id(self): - rdd = self.sc.parallelize(range(10)) - id = rdd.id() - self.assertEqual(id, rdd.id()) - rdd2 = rdd.map(str).filter(bool) - id2 = rdd2.id() - self.assertEqual(id + 1, id2) - self.assertEqual(id2, rdd2.id()) - - def test_empty_rdd(self): - rdd = self.sc.emptyRDD() - self.assertTrue(rdd.isEmpty()) - - def test_sum(self): - self.assertEqual(0, self.sc.emptyRDD().sum()) - self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) - - def test_to_localiterator(self): - from time import sleep - rdd = self.sc.parallelize([1, 2, 3]) - it = rdd.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it)) - - rdd2 = rdd.repartition(1000) - it2 = rdd2.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it2)) - - def test_save_as_textfile_with_unicode(self): - # Regression test for SPARK-970 - x = u"\u00A1Hola, mundo!" - data = self.sc.parallelize([x]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode("utf-8")) - - def test_save_as_textfile_with_utf8(self): - x = u"\u00A1Hola, mundo!" - data = self.sc.parallelize([x.encode("utf-8")]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode('utf8')) - - def test_transforming_cartesian_result(self): - # Regression test for SPARK-1034 - rdd1 = self.sc.parallelize([1, 2]) - rdd2 = self.sc.parallelize([3, 4]) - cart = rdd1.cartesian(rdd2) - result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() - - def test_transforming_pickle_file(self): - # Regression test for SPARK-2601 - data = self.sc.parallelize([u"Hello", u"World!"]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsPickleFile(tempFile.name) - pickled_file = self.sc.pickleFile(tempFile.name) - pickled_file.map(lambda x: x).collect() - - def test_cartesian_on_textfile(self): - # Regression test for - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - a = self.sc.textFile(path) - result = a.cartesian(a).collect() - (x, y) = result[0] - self.assertEqual(u"Hello World!", x.strip()) - self.assertEqual(u"Hello World!", y.strip()) - - def test_cartesian_chaining(self): - # Tests for SPARK-16589 - rdd = self.sc.parallelize(range(10), 2) - self.assertSetEqual( - set(rdd.cartesian(rdd).cartesian(rdd).collect()), - set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.cartesian(rdd)).collect()), - set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.zip(rdd)).collect()), - set([(x, (y, y)) for x in range(10) for y in range(10)]) - ) - - def 
test_zip_chaining(self): - # Tests for SPARK-21985 - rdd = self.sc.parallelize('abc', 2) - self.assertSetEqual( - set(rdd.zip(rdd).zip(rdd).collect()), - set([((x, x), x) for x in 'abc']) - ) - self.assertSetEqual( - set(rdd.zip(rdd.zip(rdd)).collect()), - set([(x, (x, x)) for x in 'abc']) - ) - - def test_deleting_input_files(self): - # Regression test for SPARK-1025 - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - def test_sampling_default_seed(self): - # Test for SPARK-3995 (default seed setting) - data = self.sc.parallelize(xrange(1000), 1) - subset = data.takeSample(False, 10) - self.assertEqual(len(subset), 10) - - def test_aggregate_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregate and treeAggregate to build dict - # representing a counter of ints - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - # Show that single or multiple partitions work - data1 = self.sc.range(10, numSlices=1) - data2 = self.sc.range(10, numSlices=2) - - def seqOp(x, y): - x[y] += 1 - return x - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) - counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) - counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - - ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) - self.assertEqual(counts1, ground_truth) - self.assertEqual(counts2, ground_truth) - self.assertEqual(counts3, ground_truth) - self.assertEqual(counts4, ground_truth) - - def test_aggregate_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that - # contains lists of all values for each key in the original RDD - - # list(range(...)) for Python 3.x compatibility (can't use * operator - # on a range object) - # list(zip(...)) for Python 3.x compatibility (want to parallelize a - # collection, not a zip object) - tuples = list(zip(list(range(10))*2, [1]*20)) - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def seqOp(x, y): - x.append(y) - return x - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.aggregateByKey([], seqOp, comboOp).collect() - values2 = data2.aggregateByKey([], seqOp, comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - ground_truth = [(i, [1]*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_fold_mutable_zero_value(self): - # Test for SPARK-9021; uses fold to merge an RDD of dict counters into - # a single dict - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - counts1 = defaultdict(int, dict((i, 1) for i in range(10))) - counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) - counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) - counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) - all_counts = [counts1, counts2, 
counts3, counts4] - # Show that single or multiple partitions work - data1 = self.sc.parallelize(all_counts, 1) - data2 = self.sc.parallelize(all_counts, 2) - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - fold1 = data1.fold(defaultdict(int), comboOp) - fold2 = data2.fold(defaultdict(int), comboOp) - - ground_truth = defaultdict(int) - for counts in all_counts: - for key, val in counts.items(): - ground_truth[key] += val - self.assertEqual(fold1, ground_truth) - self.assertEqual(fold2, ground_truth) - - def test_fold_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains - # lists of all values for each key in the original RDD - - tuples = [(i, range(i)) for i in range(10)]*2 - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.foldByKey([], comboOp).collect() - values2 = data2.foldByKey([], comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - # list(range(...)) for Python 3.x compatibility - ground_truth = [(i, list(range(i))*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_aggregate_by_key(self): - data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) - - def seqOp(x, y): - x.add(y) - return x - - def combOp(x, y): - x |= y - return x - - sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) - self.assertEqual(3, len(sets)) - self.assertEqual(set([1]), sets[1]) - self.assertEqual(set([2]), sets[3]) - self.assertEqual(set([1, 3]), sets[5]) - - def test_itemgetter(self): - rdd = self.sc.parallelize([range(10)]) - from operator import itemgetter - self.assertEqual([1], rdd.map(itemgetter(1)).collect()) - self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) - - def test_namedtuple_in_rdd(self): - from collections import namedtuple - Person = namedtuple("Person", "id firstName lastName") - jon = Person(1, "Jon", "Doe") - jane = Person(2, "Jane", "Doe") - theDoes = self.sc.parallelize([jon, jane]) - self.assertEqual([jon, jane], theDoes.collect()) - - def test_large_broadcast(self): - N = 10000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 27MB - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - - def test_unpersist(self): - N = 1000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 3MB - bdata.unpersist() - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - bdata.destroy() - try: - self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - except Exception as e: - pass - else: - raise Exception("job should fail after destroy the broadcast") - - def test_multiple_broadcasts(self): - N = 1 << 21 - b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = list(range(1 << 15)) - random.shuffle(r) - s = str(r).encode() - checksum = hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - random.shuffle(r) - s = str(r).encode() - checksum 
= hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - def test_multithread_broadcast_pickle(self): - import threading - - b1 = self.sc.broadcast(list(range(3))) - b2 = self.sc.broadcast(list(range(3))) - - def f1(): - return b1.value - - def f2(): - return b2.value - - funcs_num_pickled = {f1: None, f2: None} - - def do_pickle(f, sc): - command = (f, None, sc.serializer, sc.serializer) - ser = CloudPickleSerializer() - ser.dumps(command) - - def process_vars(sc): - broadcast_vars = list(sc._pickled_broadcast_vars) - num_pickled = len(broadcast_vars) - sc._pickled_broadcast_vars.clear() - return num_pickled - - def run(f, sc): - do_pickle(f, sc) - funcs_num_pickled[f] = process_vars(sc) - - # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage - do_pickle(f1, self.sc) - - # run all for f2, should only add/count/clear b2 from worker thread local storage - t = threading.Thread(target=run, args=(f2, self.sc)) - t.start() - t.join() - - # count number of vars pickled in main thread, only b1 should be counted and cleared - funcs_num_pickled[f1] = process_vars(self.sc) - - self.assertEqual(funcs_num_pickled[f1], 1) - self.assertEqual(funcs_num_pickled[f2], 1) - self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) - - def test_large_closure(self): - N = 200000 - data = [float(i) for i in xrange(N)] - rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEqual(N, rdd.first()) - # regression test for SPARK-6886 - self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) - - def test_zip_with_different_serializers(self): - a = self.sc.parallelize(range(5)) - b = self.sc.parallelize(range(100, 105)) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) - b = b._reserialize(MarshalSerializer()) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - # regression test for SPARK-4841 - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - t = self.sc.textFile(path) - cnt = t.count() - self.assertEqual(cnt, t.zip(t).count()) - rdd = t.map(str) - self.assertEqual(cnt, t.zip(rdd).count()) - # regression test for bug in _reserializer() - self.assertEqual(cnt, t.zip(rdd).count()) - - def test_zip_with_different_object_sizes(self): - # regress test for SPARK-5973 - a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) - b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) - self.assertEqual(10000, a.zip(b).count()) - - def test_zip_with_different_number_of_items(self): - a = self.sc.parallelize(range(5), 2) - # different number of partitions - b = self.sc.parallelize(range(100, 106), 3) - self.assertRaises(ValueError, lambda: a.zip(b)) - with QuietTest(self.sc): - # different number of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - 
self.assertEqual(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) - - def test_count_approx_distinct(self): - rdd = self.sc.parallelize(xrange(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) - - rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) - self.assertTrue(18 < rdd.countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) - - self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) - - def test_histogram(self): - # empty - rdd = self.sc.parallelize([]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - self.assertRaises(ValueError, lambda: rdd.histogram(1)) - - # out of range - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) - - # in range with one bucket - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual([4], rdd.histogram([0, 10])[1]) - self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) - - # in range with one bucket exact match - self.assertEqual([4], rdd.histogram([1, 4])[1]) - - # out of range with two buckets - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) - - # out of range with two uneven buckets - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - - # in range with two buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two bucket and None - rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two uneven buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) - - # mixed range with two uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) - - # mixed range with four uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # mixed range with uneven buckets and NaN - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, - 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # out of range with infinite buckets - rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) - - # invalid buckets - self.assertRaises(ValueError, lambda: rdd.histogram([])) - self.assertRaises(ValueError, lambda: rdd.histogram([1])) - self.assertRaises(ValueError, lambda: rdd.histogram(0)) - self.assertRaises(TypeError, lambda: rdd.histogram({})) - - # without buckets - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 4], [4]), rdd.histogram(1)) - - # without buckets single element - rdd = self.sc.parallelize([1]) - 
self.assertEqual(([1, 1], [1]), rdd.histogram(1)) - - # without bucket no range - rdd = self.sc.parallelize([1] * 4) - self.assertEqual(([1, 1], [4]), rdd.histogram(1)) - - # without buckets basic two - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) - - # without buckets with more requested than elements - rdd = self.sc.parallelize([1, 2]) - buckets = [1 + 0.2 * i for i in range(6)] - hist = [1, 0, 0, 0, 1] - self.assertEqual((buckets, hist), rdd.histogram(5)) - - # invalid RDDs - rdd = self.sc.parallelize([1, float('inf')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - rdd = self.sc.parallelize([float('nan')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - - # string - rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - def test_repartitionAndSortWithinPartitions_asc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) - - def test_repartitionAndSortWithinPartitions_desc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) - self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) - - def test_repartition_no_skewed(self): - num_partitions = 20 - a = self.sc.parallelize(range(int(1000)), 2) - l = a.repartition(num_partitions).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - l = a.coalesce(num_partitions, True).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - - def test_repartition_on_textfile(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - rdd = self.sc.textFile(path) - result = rdd.repartition(1).collect() - self.assertEqual(u"Hello World!", result[0]) - - def test_distinct(self): - rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEqual(rdd.getNumPartitions(), 10) - self.assertEqual(rdd.distinct().count(), 3) - result = rdd.distinct(5) - self.assertEqual(result.getNumPartitions(), 5) - self.assertEqual(result.count(), 3) - - def test_external_group_by_key(self): - self.sc._conf.set("spark.python.worker.memory", "1m") - N = 200001 - kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) - gkv = kv.groupByKey().cache() - self.assertEqual(3, gkv.count()) - filtered = gkv.filter(lambda kv: kv[0] == 1) - self.assertEqual(1, filtered.count()) - self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) - self.assertEqual([(N // 3, N // 3)], - filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) - result = filtered.collect()[0][1] - self.assertEqual(N // 3, len(result)) - self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) - - def test_sort_on_empty_rdd(self): - self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) - - def test_sample(self): - rdd = self.sc.parallelize(range(0, 100), 4) - wo = rdd.sample(False, 0.1, 
2).collect() - wo_dup = rdd.sample(False, 0.1, 2).collect() - self.assertSetEqual(set(wo), set(wo_dup)) - wr = rdd.sample(True, 0.2, 5).collect() - wr_dup = rdd.sample(True, 0.2, 5).collect() - self.assertSetEqual(set(wr), set(wr_dup)) - wo_s10 = rdd.sample(False, 0.3, 10).collect() - wo_s20 = rdd.sample(False, 0.3, 20).collect() - self.assertNotEqual(set(wo_s10), set(wo_s20)) - wr_s11 = rdd.sample(True, 0.4, 11).collect() - wr_s21 = rdd.sample(True, 0.4, 21).collect() - self.assertNotEqual(set(wr_s11), set(wr_s21)) - - def test_null_in_rdd(self): - jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) - rdd = RDD(jrdd, self.sc, UTF8Deserializer()) - self.assertEqual([u"a", None, u"b"], rdd.collect()) - rdd = RDD(jrdd, self.sc, NoOpSerializer()) - self.assertEqual([b"a", None, b"b"], rdd.collect()) - - def test_multiple_python_java_RDD_conversions(self): - # Regression test for SPARK-5361 - data = [ - (u'1', {u'director': u'David Lean'}), - (u'2', {u'director': u'Andrew Dominik'}) - ] - data_rdd = self.sc.parallelize(data) - data_java_rdd = data_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - # conversion between python and java RDD threw exceptions - data_java_rdd = converted_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - def test_narrow_dependency_in_join(self): - rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) - parted = rdd.partitionBy(2) - self.assertEqual(2, parted.union(parted).getNumPartitions()) - self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) - self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - - tracker = self.sc.statusTracker() - - self.sc.setJobGroup("test1", "test", True) - d = sorted(parted.join(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test1")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test2", "test", True) - d = sorted(parted.join(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test2")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test3", "test", True) - d = sorted(parted.cogroup(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test3")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test4", "test", True) - d = sorted(parted.cogroup(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test4")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - # Regression test for SPARK-6294 - def test_take_on_jrdd(self): - rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) - rdd._jrdd.first() - - def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): - # Regression test for SPARK-5969 - seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence - rdd = self.sc.parallelize(seq) - for ascending in [True, False]: - sort = rdd.sortByKey(ascending=ascending, numPartitions=5) - 
self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) - sizes = sort.glom().map(len).collect() - for size in sizes: - self.assertGreater(size, 0) - - def test_pipe_functions(self): - data = ['1', '2', '3'] - rdd = self.sc.parallelize(data) - with QuietTest(self.sc): - self.assertEqual([], rdd.pipe('cc').collect()) - self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) - result = rdd.pipe('cat').collect() - result.sort() - for x, y in zip(data, result): - self.assertEqual(x, y) - self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) - self.assertEqual([], rdd.pipe('grep 4').collect()) - - def test_pipe_unicode(self): - # Regression test for SPARK-20947 - data = [u'\u6d4b\u8bd5', '1'] - rdd = self.sc.parallelize(data) - result = rdd.pipe('cat').collect() - self.assertEqual(data, result) - - def test_stopiteration_in_user_code(self): - - def stopit(*x): - raise StopIteration() - - seq_rdd = self.sc.parallelize(range(10)) - keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - msg = "Caught StopIteration thrown from user's code; failing the task" - - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, - seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - - # these methods call the user function both in the driver and in the executor - # the exception raised is different according to where the StopIteration happens - # RuntimeError is raised if in the driver - # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, lambda *x: 1, stopit) - - -class ProfilerTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, conf=conf) - - def test_profiler(self): - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - id, profiler, _ = profilers[0] - stats = profiler.stats() - self.assertTrue(stats is not None) - width, stat_list = stats.get_print_list([]) - func_names = [func_name for fname, n, func_name in stat_list] - self.assertTrue("heavy_foo" in func_names) - - old_stdout = sys.stdout - sys.stdout = io = StringIO() - self.sc.show_profiles() - self.assertTrue("heavy_foo" in io.getvalue()) - sys.stdout = old_stdout - - d = tempfile.gettempdir() - self.sc.dump_profiles(d) - self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) - - def test_custom_profiler(self): - class TestCustomProfiler(BasicProfiler): - def show(self, id): - self.result = "Custom formatting" - - self.sc.profiler_collector.profiler_cls = TestCustomProfiler - - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - _, profiler, _ = 
profilers[0] - self.assertTrue(isinstance(profiler, TestCustomProfiler)) - - self.sc.show_profiles() - self.assertEqual("Custom formatting", profiler.result) - - def do_computation(self): - def heavy_foo(x): - for i in range(1 << 18): - x = 1 - - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - - -class ProfilerTests2(unittest.TestCase): - def test_profiler_disabled(self): - sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) - try: - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.show_profiles()) - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.dump_profiles("/tmp/abc")) - finally: - sc.stop() - - -class InputFormatTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", - "org.apache.hadoop.io.DoubleWritable", - "org.apache.hadoop.io.Text").collect()) - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.assertEqual(doubles, ed) - - bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) - ebs = [(1, bytearray('aa', 'utf-8')), - (1, bytearray('aa', 'utf-8')), - (2, bytearray('aa', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (3, bytearray('cc', 'utf-8'))] - self.assertEqual(bytes, ebs) - - text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", - "org.apache.hadoop.io.Text", - "org.apache.hadoop.io.Text").collect()) - et = [(u'1', u'aa'), - (u'1', u'aa'), - (u'2', u'aa'), - (u'2', u'bb'), - (u'2', u'bb'), - (u'3', u'cc')] - self.assertEqual(text, et) - - bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.assertEqual(bools, eb) - - nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.assertEqual(nulls, en) - - maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - for v in maps: - self.assertTrue(v in em) - - # arrays get pickled to tuples by default - tuples = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - 
"org.apache.spark.api.python.DoubleArrayWritable").collect()) - et = [(1, ()), - (2, (3.0, 4.0, 5.0)), - (3, (4.0, 5.0, 6.0))] - self.assertEqual(tuples, et) - - # with custom converters, primitive arrays can stay as arrays - arrays = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - ea = [(1, array('d')), - (2, array('d', [3.0, 4.0, 5.0])), - (3, array('d', [4.0, 5.0, 6.0]))] - self.assertEqual(arrays, ea) - - clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable").collect()) - cname = u'org.apache.spark.api.python.TestWritable' - ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), - (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), - (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), - (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), - (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] - self.assertEqual(clazz, ec) - - unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - ).collect()) - self.assertEqual(unbatched_clazz, ec) - - def test_oldhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=oldconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=newconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newolderror(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def 
test_bad_inputs(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.sequenceFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.NotValidWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - maps = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - keyConverter="org.apache.spark.api.python.TestInputKeyConverter", - valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) - em = [(u'\x01', []), - (u'\x01', [3.0]), - (u'\x02', [1.0]), - (u'\x02', [1.0]), - (u'\x03', [2.0])] - self.assertEqual(maps, em) - - def test_binary_files(self): - path = os.path.join(self.tempdir.name, "binaryfiles") - os.mkdir(path) - data = b"short binary data" - with open(os.path.join(path, "part-0000"), 'wb') as f: - f.write(data) - [(p, d)] = self.sc.binaryFiles(path).collect() - self.assertTrue(p.endswith("part-0000")) - self.assertEqual(d, data) - - def test_binary_records(self): - path = os.path.join(self.tempdir.name, "binaryrecords") - os.mkdir(path) - with open(os.path.join(path, "part-0000"), 'w') as f: - for i in range(100): - f.write('%04d' % i) - result = self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(list(range(100)), result) - - -class OutputFormatTests(ReusedPySparkTestCase): - - def setUp(self): - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - - def tearDown(self): - shutil.rmtree(self.tempdir.name, ignore_errors=True) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") - ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) - self.assertEqual(ints, ei) - - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") - doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) - self.assertEqual(doubles, ed) - - ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] - self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") - bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) - self.assertEqual(bytes, ebs) - - et = [(u'1', u'aa'), - (u'2', u'bb'), - (u'3', u'cc')] - self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") - text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) - self.assertEqual(text, et) - - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") - bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) - self.assertEqual(bools, eb) - - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - 
self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") - nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) - self.assertEqual(nulls, en) - - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() - for v in maps: - self.assertTrue(v, em) - - def test_oldhadoop(self): - basepath = self.tempdir.name - dict_data = [(1, {}), - (1, {"row1": 1.0}), - (2, {"row2": 2.0})] - self.sc.parallelize(dict_data).saveAsHadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable") - result = self.sc.hadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - for v in result: - self.assertTrue(v, dict_data) - - conf = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" - } - self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} - result = self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect() - for v in result: - self.assertTrue(v, dict_data) - - def test_newhadoop(self): - basepath = self.tempdir.name - data = [(1, ""), - (1, "a"), - (2, "bcdf")] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - self.assertEqual(result, data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=input_conf).collect()) - self.assertEqual(new_dataset, data) - - @unittest.skipIf(sys.version >= "3", "serialize of array") - def test_newhadoop_with_array(self): - basepath = self.tempdir.name - # use custom ArrayWritable types and converters to handle arrays - array_data = [(1, array('d')), - (1, array('d', [1.0, 2.0, 3.0])), - (2, array('d', [3.0, 4.0, 5.0]))] - self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - 
"org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - self.assertEqual(result, array_data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( - conf, - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", - conf=input_conf).collect()) - self.assertEqual(new_dataset, array_data) - - def test_newolderror(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/newolderror/saveAsHadoopFile/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/newolderror/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - def test_bad_inputs(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/badinputs/saveAsHadoopFile/", - "org.apache.hadoop.mapred.NotValidOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/badinputs/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - data = [(1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/converters/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", - valueConverter="org.apache.spark.api.python.TestOutputValueConverter") - converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) - expected = [(u'1', 3.0), - (u'2', 1.0), - (u'3', 2.0)] - self.assertEqual(converted, expected) - - def test_reserialization(self): - basepath = self.tempdir.name - x = range(1, 5) - y = range(1001, 1005) - data = list(zip(x, y)) - rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) - rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") - result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) - self.assertEqual(result1, data) - - rdd.saveAsHadoopFile( - basepath + "/reserialize/hadoop", - "org.apache.hadoop.mapred.SequenceFileOutputFormat") - result2 = 
sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) - self.assertEqual(result2, data) - - rdd.saveAsNewAPIHadoopFile( - basepath + "/reserialize/newhadoop", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) - self.assertEqual(result3, data) - - conf4 = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} - rdd.saveAsHadoopDataset(conf4) - result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) - self.assertEqual(result4, data) - - conf5 = {"mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" - } - rdd.saveAsNewAPIHadoopDataset(conf5) - result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) - self.assertEqual(result5, data) - - def test_malformed_RDD(self): - basepath = self.tempdir.name - # non-batch-serialized RDD[[(K, V)]] should be rejected - data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] - rdd = self.sc.parallelize(data, len(data)) - self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( - basepath + "/malformed/sequence")) - - -class DaemonTests(unittest.TestCase): - def connect(self, port): - from socket import socket, AF_INET, SOCK_STREAM - sock = socket(AF_INET, SOCK_STREAM) - sock.connect(('127.0.0.1', port)) - # send a split index of -1 to shutdown the worker - sock.send(b"\xFF\xFF\xFF\xFF") - sock.close() - return True - - def do_termination_test(self, terminator): - from subprocess import Popen, PIPE - from errno import ECONNREFUSED - - # start daemon - daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") - daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) - - # read the port number - port = read_int(daemon.stdout) - - # daemon should accept connections - self.assertTrue(self.connect(port)) - - # request shutdown - terminator(daemon) - time.sleep(1) - - # daemon should no longer accept connections - try: - self.connect(port) - except EnvironmentError as exception: - self.assertEqual(exception.errno, ECONNREFUSED) - else: - self.fail("Expected EnvironmentError to be raised") - - def test_termination_stdin(self): - """Ensure that daemon and workers terminate when stdin is closed.""" - self.do_termination_test(lambda daemon: daemon.stdin.close()) - - def test_termination_sigterm(self): - """Ensure that daemon and workers terminate on SIGTERM.""" - from signal import SIGTERM - self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) - - -class WorkerTests(ReusedPySparkTestCase): - def test_cancel_task(self): - temp = tempfile.NamedTemporaryFile(delete=True) - temp.close() - path = temp.name - - def sleep(x): - import os - import time - with open(path, 'w') as f: - f.write("%d %d" % (os.getppid(), os.getpid())) - time.sleep(100) - - # start job in background thread - def run(): - try: - self.sc.parallelize(range(1), 1).foreach(sleep) - except 
Exception: - pass - import threading - t = threading.Thread(target=run) - t.daemon = True - t.start() - - daemon_pid, worker_pid = 0, 0 - while True: - if os.path.exists(path): - with open(path) as f: - data = f.read().split(' ') - daemon_pid, worker_pid = map(int, data) - break - time.sleep(0.1) - - # cancel jobs - self.sc.cancelAllJobs() - t.join() - - for i in range(50): - try: - os.kill(worker_pid, 0) - time.sleep(0.1) - except OSError: - break # worker was killed - else: - self.fail("worker has not been killed after 5 seconds") - - try: - os.kill(daemon_pid, 0) - except OSError: - self.fail("daemon had been killed") - - # run a normal job - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_exception(self): - def raise_exception(_): - raise Exception() - rdd = self.sc.parallelize(xrange(100), 1) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_jvm_exception(self): - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name, 1) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_accumulator_when_reuse_worker(self): - from pyspark.accumulators import INT_ACCUMULATOR_PARAM - acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) - self.assertEqual(sum(range(100)), acc1.value) - - acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) - self.assertEqual(sum(range(100)), acc2.value) - self.assertEqual(sum(range(100)), acc1.value) - - def test_reuse_worker_after_take(self): - rdd = self.sc.parallelize(xrange(100000), 1) - self.assertEqual(0, rdd.first()) - - def count(): - try: - rdd.count() - except Exception: - pass - - t = threading.Thread(target=count) - t.daemon = True - t.start() - t.join(5) - self.assertTrue(not t.isAlive()) - self.assertEqual(100000, rdd.count()) - - def test_with_different_versions_of_python(self): - rdd = self.sc.parallelize(range(10)) - rdd.count() - version = self.sc.pythonVer - self.sc.pythonVer = "2.0" - try: - with QuietTest(self.sc): - self.assertRaises(Py4JJavaError, lambda: rdd.count()) - finally: - self.sc.pythonVer = version - - -class SparkSubmitTests(unittest.TestCase): - - def setUp(self): - self.programDir = tempfile.mkdtemp() - tmp_dir = tempfile.gettempdir() - self.sparkSubmit = [ - os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), - "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - ] - - def tearDown(self): - shutil.rmtree(self.programDir) - - def createTempFile(self, name, content, dir=None): - """ - Create a temp file with the given name and content and return its path. - Strips leading spaces from content up to the first '|' in each line. 
- """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name) - else: - os.makedirs(os.path.join(self.programDir, dir)) - path = os.path.join(self.programDir, dir, name) - with open(path, "w") as f: - f.write(content) - return path - - def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): - """ - Create a zip archive containing a file with the given content and return its path. - Strips leading spaces from content up to the first '|' in each line. - """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name + ext) - else: - path = os.path.join(self.programDir, dir, zip_name + ext) - zip = zipfile.ZipFile(path, 'w') - zip.writestr(name, content) - zip.close() - return path - - def create_spark_package(self, artifact_name): - group_id, artifact_id, version = artifact_name.split(":") - self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" - | - | - | 4.0.0 - | %s - | %s - | %s - | - """ % (group_id, artifact_id, version)).lstrip(), - os.path.join(group_id, artifact_id, version)) - self.createFileInZip("%s.py" % artifact_id, """ - |def myfunc(x): - | return x + 1 - """, ".jar", os.path.join(group_id, artifact_id, version), - "%s-%s" % (artifact_id, version)) - - def test_single_script(self): - """Submit and test a single script file""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_script_with_local_functions(self): - """Submit and test a single script file calling a global function""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 3 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out.decode('utf-8')) - - def test_module_dependency(self): - """Submit and test a script with a dependency on another module""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_module_dependency_on_cluster(self): - """Submit and test a script with a dependency on another module on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", - "local-cluster[1,1,1024]", script], - 
stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency(self): - """Submit and test a script with a dependency on a Spark Package""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency_on_cluster(self): - """Submit and test a script with a dependency on a Spark Package on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", - script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_single_script_on_cluster(self): - """Submit and test a single script on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 2 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - # this will fail if you have different spark.executor.memory - # in conf/spark-defaults.conf - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_user_configuration(self): - """Make sure user configuration is respected (SPARK-19307)""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkConf, SparkContext - | - |conf = SparkConf().set("spark.test_config", "1") - |sc = SparkContext(conf = conf) - |try: - | if sc._conf.get("spark.test_config") != "1": - | raise Exception("Cannot find spark.test_config in SparkContext's conf.") - |finally: - | sc.stop() - """) - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local", script], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) - - def test_conda(self): - """Submit and test a single script file via conda""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |sc = SparkContext() - |sc.addCondaPackages('numpy=1.14.0') - | - |# Ensure numpy is accessible on the driver - |import numpy - |arr = [1, 2, 3] - |def mul2(x): - | # Also ensure numpy accessible from executor - | assert numpy.version.version == "1.14.0" - | return x * 2 - |print(sc.parallelize(arr).map(mul2).collect()) - """) - props = self.createTempFile("properties", """ - |spark.conda.binaryPath {} - |spark.conda.channelUrls https://repo.continuum.io/pkgs/main - |spark.conda.bootstrapPackages python=3.5 - """.format(os.environ["CONDA_BIN"])) 
- env = dict(os.environ) - del env['PYSPARK_PYTHON'] - del env['PYSPARK_DRIVER_PYTHON'] - proc = subprocess.Popen(self.sparkSubmit + [ - "--properties-file", props, script], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env) - out, err = proc.communicate() - if 0 != proc.returncode: - self.fail(("spark-submit was unsuccessful with error code {}\n\n" + - "stdout:\n{}\n\nstderr:\n{}").format(proc.returncode, out, err)) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - -class ContextTests(unittest.TestCase): - - def test_failed_sparkcontext_creation(self): - # Regression test for SPARK-1550 - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - - def test_get_or_create(self): - with SparkContext.getOrCreate() as sc: - self.assertTrue(SparkContext.getOrCreate() is sc) - - def test_parallelize_eager_cleanup(self): - with SparkContext() as sc: - temp_files = os.listdir(sc._temp_dir) - rdd = sc.parallelize([0, 1, 2]) - post_parallalize_temp_files = os.listdir(sc._temp_dir) - self.assertEqual(temp_files, post_parallalize_temp_files) - - def test_set_conf(self): - # This is for an internal use case. When there is an existing SparkContext, - # SparkSession's builder needs to set configs into SparkContext's conf. - sc = SparkContext() - sc._conf.set("spark.test.SPARK16224", "SPARK16224") - self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") - sc.stop() - - def test_stop(self): - sc = SparkContext() - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_exception(self): - try: - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - raise Exception() - except: - pass - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_stop(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_progress_api(self): - with SparkContext() as sc: - sc.setJobGroup('test_progress_api', '', True) - rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) - - def run(): - try: - rdd.count() - except Exception: - pass - t = threading.Thread(target=run) - t.daemon = True - t.start() - # wait for scheduler to start - time.sleep(1) - - tracker = sc.statusTracker() - jobIds = tracker.getJobIdsForGroup('test_progress_api') - self.assertEqual(1, len(jobIds)) - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual(1, len(job.stageIds)) - stage = tracker.getStageInfo(job.stageIds[0]) - self.assertEqual(rdd.getNumPartitions(), stage.numTasks) - - sc.cancelAllJobs() - t.join() - # wait for event listener to update the status - time.sleep(1) - - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual('FAILED', job.status) - self.assertEqual([], tracker.getActiveJobsIds()) - self.assertEqual([], tracker.getActiveStageIds()) - - sc.stop() - - def test_startTime(self): - with SparkContext() as sc: - self.assertGreater(sc.startTime, 0) - - -class ConfTests(unittest.TestCase): - def test_memory_conf(self): - memoryList = ["1T", "1G", "1M", "1024K"] - for memory in memoryList: - sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) - l = list(range(1024)) - random.shuffle(l) - 
rdd = sc.parallelize(l, 4) - self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class KeywordOnlyTests(unittest.TestCase): - class Wrapped(object): - @keyword_only - def set(self, x=None, y=None): - if "x" in self._input_kwargs: - self._x = self._input_kwargs["x"] - if "y" in self._input_kwargs: - self._y = self._input_kwargs["y"] - return x, y - - def test_keywords(self): - w = self.Wrapped() - x, y = w.set(y=1) - self.assertEqual(y, 1) - self.assertEqual(y, w._y) - self.assertIsNone(x) - self.assertFalse(hasattr(w, "_x")) - - def test_non_keywords(self): - w = self.Wrapped() - self.assertRaises(TypeError, lambda: w.set(0, y=1)) - - def test_kwarg_ownership(self): - # test _input_kwargs is owned by each class instance and not a shared static variable - class Setter(object): - @keyword_only - def set(self, x=None, other=None, other_x=None): - if "other" in self._input_kwargs: - self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) - self._x = self._input_kwargs["x"] - - a = Setter() - b = Setter() - a.set(x=1, other=b, other_x=2) - self.assertEqual(a._x, 1) - self.assertEqual(b._x, 2) - - -class UtilTests(PySparkTestCase): - def test_py4j_exception_message(self): - from pyspark.util import _exception_message - - with self.assertRaises(Py4JJavaError) as context: - # This attempts java.lang.String(null) which throws an NPE. - self.sc._jvm.java.lang.String(None) - - self.assertTrue('NullPointerException' in _exception_message(context.exception)) - - def test_parsing_version_string(self): - from pyspark.util import VersionUtils - self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): - - """General PySpark tests that depend on scipy """ - - def test_serialize(self): - from scipy.special import gammaln - x = range(1, 5) - expected = list(map(gammaln, x)) - observed = self.sc.parallelize(x).map(gammaln).collect() - self.assertEqual(expected, observed) - - -@unittest.skipIf(not _have_numpy, "NumPy not installed") -class NumPyTests(PySparkTestCase): - - """General PySpark tests that depend on numpy """ - - def test_statcounter_array(self): - x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) - s = x.stats() - self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) - self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) - - stats_dict = s.asDict() - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) - - stats_sample_dict = s.asDict(sample=True) - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) - self.assertSequenceEqual( - [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) - 
self.assertSequenceEqual( - [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) - - -if __name__ == "__main__": - from pyspark.tests import * - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 5cb8f948de637..24aebef8633ad 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -189,6 +189,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Seq.empty), "prefix", "appId", + None, Map.empty, Map.empty, Map.empty, @@ -202,6 +203,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, + LOCAL_FILES_STEP_TYPE, MOUNT_VOLUMES_STEP_TYPE, DRIVER_CMD_STEP_TYPE) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 046da80031451..787b8cc52021e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -37,14 +37,10 @@ RUN set -ex && \ COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -<<<<<<< HEAD -COPY ${img_path}/spark/entrypoint.sh /opt/ -======= COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ->>>>>>> 87bd9c75df ENV SPARK_HOME /opt/spark diff --git a/spark-docker-image-generator/src/test/resources/ExpectedDockerfile b/spark-docker-image-generator/src/test/resources/ExpectedDockerfile index 2e0613cd2a826..31ec83d3db601 100644 --- a/spark-docker-image-generator/src/test/resources/ExpectedDockerfile +++ b/spark-docker-image-generator/src/test/resources/ExpectedDockerfile @@ -17,11 +17,6 @@ FROM fabric8/java-centos-openjdk8-jdk:latest -ARG spark_jars=jars -ARG example_jars=examples/jars -ARG img_path=kubernetes/dockerfiles -ARG k8s_tests=kubernetes/tests - # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. 
# If this docker file is being used in the context of building your images from a Spark @@ -39,10 +34,13 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY ${spark_jars} /opt/spark/jars +COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY kubernetes/tests /opt/spark/tests +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index c8a68da151a03..17dd317f63b70 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -93,11 +93,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -<<<<<<< HEAD -Partition Statistics 1264 bytes, 3 rows -======= Partition Statistics [not included in comparison] bytes, 3 rows ->>>>>>> 87bd9c75df # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -132,11 +128,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -<<<<<<< HEAD -Partition Statistics 1264 bytes, 3 rows -======= Partition Statistics [not included in comparison] bytes, 3 rows ->>>>>>> 87bd9c75df # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -163,11 +155,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -<<<<<<< HEAD -Partition Statistics 1278 bytes, 4 rows -======= Partition Statistics [not included in comparison] bytes, 4 rows ->>>>>>> 87bd9c75df # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -202,11 +190,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -<<<<<<< HEAD -Partition Statistics 1264 bytes, 3 rows -======= Partition Statistics [not included in comparison] bytes, 3 rows ->>>>>>> 87bd9c75df # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -233,11 +217,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -<<<<<<< HEAD -Partition Statistics 1278 bytes, 4 rows -======= Partition Statistics [not included in comparison] bytes, 4 rows ->>>>>>> 87bd9c75df # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -264,11 +244,7 @@ Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 Created Time [not included in comparison] Last Access [not included in comparison] -<<<<<<< HEAD -Partition Statistics 1250 bytes, 2 rows -======= Partition 
Statistics [not included in comparison] bytes, 2 rows ->>>>>>> 87bd9c75df # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0bbb27760789a..1bfe3818742db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -85,7 +85,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } -<<<<<<< HEAD test("union all") { val unionDF = testData.union(testData).union(testData) .union(testData).union(testData) @@ -209,8 +208,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } -======= ->>>>>>> 87bd9c75df test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index c28ebbea4d41b..d35268580ba38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -506,11 +506,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized -<<<<<<< HEAD - assert(inMemoryRelation.computeStats().sizeInBytes === 848) -======= - assert(inMemoryRelation.computeStats().sizeInBytes === 868) ->>>>>>> 87bd9c75df + assert(inMemoryRelation.computeStats().sizeInBytes === 916) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -523,11 +519,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated -<<<<<<< HEAD - assert(inMemoryRelation2.computeStats().sizeInBytes === 848) -======= - assert(inMemoryRelation2.computeStats().sizeInBytes === 868) ->>>>>>> 87bd9c75df + assert(inMemoryRelation2.computeStats().sizeInBytes === 916) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index b9f4945c71964..1229d31f4cc07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -45,11 +45,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { import testImplicits._ Seq(1.0, 0.5).foreach { compressionFactor => withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, -<<<<<<< HEAD - "spark.sql.autoBroadcastJoinThreshold" -> "424") { -======= - "spark.sql.autoBroadcastJoinThreshold" -> "434") { ->>>>>>> 87bd9c75df + "spark.sql.autoBroadcastJoinThreshold" -> "458") { withTempPath { workDir => // the file size is 740 bytes val workDirPath = workDir.getAbsolutePath From 
f97d5d4052fa8ec715f874d41078d0f5b21be49c Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 1 Feb 2019 09:58:29 +0000 Subject: [PATCH 143/145] Revert "[SPARK-25867][ML] Remove KMeans computeCost" This reverts commit dd8c179c28c5df20210b70a69d93d866ccaca4cc. --- .../org/apache/spark/ml/clustering/KMeans.scala | 16 ++++++++++++++++ .../apache/spark/ml/clustering/KMeansSuite.scala | 12 +++++++----- project/MimaExcludes.scala | 3 --- python/pyspark/ml/clustering.py | 16 ++++++++++++++++ 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 2eed84d51782a..5d02305aafdda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -143,6 +143,22 @@ class KMeansModel private[ml] ( @Since("2.0.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) + /** + * Return the K-means cost (sum of squared distances of points to their nearest center) for this + * model on the given data. + * + * @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator + * instead. You can also get the cost on the training dataset in the summary. + */ + @deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " + + "instead. You can also get the cost on the training dataset in the summary.", "2.4.0") + @Since("2.0.0") + def computeCost(dataset: Dataset[_]): Double = { + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) + } + /** * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. 
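The hunk above restores the deprecated `KMeansModel.computeCost` (sum of squared distances of points to their nearest center). For reference, a minimal sketch of calling the restored method alongside the `ClusteringEvaluator` that the deprecation note points to; the `SparkSession` (`spark`), the tiny dataset, and all values are illustrative, not taken from this patch:

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.linalg.Vectors

// Illustrative dataset with a vector "features" column (assumes a SparkSession named `spark`).
val df = spark.createDataFrame(Seq(
  (Vectors.dense(0.0, 0.0), 0),
  (Vectors.dense(0.1, 0.1), 0),
  (Vectors.dense(9.0, 9.0), 1),
  (Vectors.dense(9.1, 9.1), 1)
)).toDF("features", "label")

val model = new KMeans().setK(2).setSeed(1L).fit(df)

// Deprecated API restored by this revert: within-set sum of squared errors on the given data.
val cost = model.computeCost(df)

// Suggested replacement from the deprecation note: silhouette score via ClusteringEvaluator.
val silhouette = new ClusteringEvaluator().evaluate(model.transform(df))

println(s"WSSSE = $cost, silhouette = $silhouette")
```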
* diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 4f47d91f0d0d5..ccbceab53bb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -117,6 +117,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(clusters === Set(0, 1, 2, 3, 4)) } + assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) // Check validity of model summary @@ -131,6 +132,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } assert(summary.cluster.columns === Array(predictionColName)) assert(summary.trainingCost < 0.1) + assert(model.computeCost(dataset) == summary.trainingCost) val clusterSizes = summary.clusterSizes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) @@ -199,15 +201,15 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } test("KMean with Array input") { - def trainAndGetCost(dataset: Dataset[_]): Double = { + def trainAndComputeCost(dataset: Dataset[_]): Double = { val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) - model.summary.trainingCost + model.computeCost(dataset) } val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) - val trueCost = trainAndGetCost(newDataset) - val doubleArrayCost = trainAndGetCost(newDatasetD) - val floatArrayCost = trainAndGetCost(newDatasetF) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) // checking the cost is fine enough as a sanity check assert(trueCost ~== doubleArrayCost absTol 1e-6) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eacda42813130..2f290e26b7f83 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,9 +36,6 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( - // [SPARK-25867] Remove KMeans computeCost - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), - // [SPARK-26127] Remove deprecated setters from tree regression and classification models ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index d0b507ec5dad4..aaeeeb82d3d86 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -335,6 +335,20 @@ def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] + @since("2.0.0") + def computeCost(self, dataset): + """ + Return the K-means cost (sum of squared distances of points to their nearest center) + for this model on the given data. + + ..note:: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead. + You can also get the cost on the training dataset in the summary. + """ + warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator " + "instead. 
You can also get the cost on the training dataset in the summary.", + DeprecationWarning) + return self._call_java("computeCost", dataset) + @property @since("2.1.0") def hasSummary(self): @@ -373,6 +387,8 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol >>> centers = model.clusterCenters() >>> len(centers) 2 + >>> model.computeCost(df) + 2.0 >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction From c3524ded0bb713674b46db3cc9063f72b9eb87fb Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 1 Feb 2019 09:58:51 +0000 Subject: [PATCH 144/145] Revert "[SPARK-26127][ML] Remove deprecated setters from tree regression and classification models" This reverts commit 4aa9ccbde7870fb2750712e9e38e6aad740e0770. --- .../DecisionTreeClassifier.scala | 18 +-- .../ml/classification/GBTClassifier.scala | 26 ++--- .../RandomForestClassifier.scala | 24 ++-- .../ml/regression/DecisionTreeRegressor.scala | 18 +-- .../spark/ml/regression/GBTRegressor.scala | 27 ++--- .../ml/regression/RandomForestRegressor.scala | 24 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 105 ++++++++++++++++++ project/MimaExcludes.scala | 74 +----------- 8 files changed, 178 insertions(+), 138 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index bcf89766b0873..6648e78d8eafa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -55,27 +55,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -87,15 +87,15 @@ class DecisionTreeClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) override protected def train( dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 09a9df6d15ece..2c4186a13d8f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -69,27 +69,27 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -101,7 +101,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. 
@@ -110,7 +110,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = { + override def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this } @@ -119,25 +119,25 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) + override def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - def setStepSize(value: Double): this.type = set(stepSize, value) + override def setStepSize(value: Double): this.type = set(stepSize, value) /** @group setParam */ @Since("2.3.0") - def setFeatureSubsetStrategy(value: String): this.type = + override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) // Parameters from GBTClassifierParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 0a3bfd1f85e08..7598a28b6f89d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -57,27 +57,27 @@ class RandomForestClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -89,31 +89,31 @@ class RandomForestClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) + override def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - def setFeatureSubsetStrategy(value: String): this.type = + override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train( diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index faadc4d7b4ccc..c9de85de42fa5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -54,27 +54,27 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S // Override parameter setters from parent trait for Java API compatibility. /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -86,15 +86,15 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 9b386ef5eed8f..88dee2507bf7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -68,27 +69,27 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -100,7 +101,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. 
@@ -109,7 +110,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = { + override def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } @@ -118,21 +119,21 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) + override def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - def setStepSize(value: Double): this.type = set(stepSize, value) + override def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTRegressorParams: @@ -142,7 +143,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("2.3.0") - def setFeatureSubsetStrategy(value: String): this.type = + override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index afa9a646412b3..a548ec537bb44 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -56,27 +56,27 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** @group setParam */ @Since("1.4.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) + override def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
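This revert puts the deprecated `set*` methods back on the shared tree params traits (so fitted models expose them again) and has the concrete estimators re-declare theirs as `override def`; the chained configuration style at call sites is unchanged. A minimal sketch of that style, with illustrative parameter values and an assumed `trainingDf`:

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

// Builder-style configuration through the setters touched by this revert.
val rf = new RandomForestClassifier()
  .setNumTrees(50)
  .setMaxDepth(6)
  .setMaxBins(32)
  .setMinInstancesPerNode(2)
  .setSubsamplingRate(0.8)
  .setFeatureSubsetStrategy("sqrt")
  .setSeed(42L)

// `trainingDf` (with "features"/"label" columns) is assumed to exist.
// val model = rf.fit(trainingDf)
```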
@@ -88,31 +88,31 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - def setImpurity(value: String): this.type = set(impurity, value) + override def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - def setSeed(value: Long): this.type = set(seed, value) + override def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) + override def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - def setFeatureSubsetStrategy(value: String): this.type = + override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train( diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index c06c68d44ae1c..f1e3836ebe476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -110,24 +110,80 @@ private[ml] trait DecisionTreeParams extends PredictorParams setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMaxDepth(value: Int): this.type = set(maxDepth, value) + /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMaxBins(value: Int): this.type = set(maxBins, value) + /** @group getParam */ final def getMaxBins: Int = $(maxBins) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. 
+ * @group expertSetParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + /** @group expertGetParam */ final def getMaxMemoryInMB: Int = $(maxMemoryInMB) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group expertSetParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], @@ -170,6 +226,13 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setImpurity(value: String): this.type = set(impurity, value) + /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -210,6 +273,13 @@ private[ml] trait HasVarianceImpurity extends Params { setDefault(impurity -> "variance") + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setImpurity(value: String): this.type = set(impurity, value) + /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -276,6 +346,13 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(subsamplingRate -> 1.0) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) @@ -329,6 +406,13 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(featureSubsetStrategy -> "auto") + /** + * @deprecated This method is deprecated and will be removed in 3.0.0 + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + /** @group getParam */ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } @@ -356,6 +440,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(numTrees -> 20) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. 
+ * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setNumTrees(value: Int): this.type = set(numTrees, value) + /** @group getParam */ final def getNumTrees: Int = $(numTrees) } @@ -400,6 +491,13 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @Since("2.4.0") final def getValidationTol: Double = $(validationTol) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + /** * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking * the contribution of each estimator. @@ -410,6 +508,13 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.", ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setStepSize(value: Double): this.type = set(stepSize, value) + setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2f290e26b7f83..050c65efbe164 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,76 +36,6 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( - // [SPARK-26127] Remove deprecated setters from tree regression and classification models - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setNumTrees"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setNumTrees"), - // [SPARK-26124] Update plugins, including MiMa ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns.build"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.fullSchema"), @@ -120,11 +50,15 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), From 1f9772d96c0e8bb17fa92cbe3d8367646c7e07f8 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 1 Feb 2019 10:27:16 +0000 Subject: [PATCH 145/145] Add note about reverts --- FORK.md | 4 +++- project/MimaExcludes.scala | 7 ------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/FORK.md b/FORK.md index 0f99790529174..055f925abd879 100644 --- a/FORK.md +++ b/FORK.md @@ -25,4 +25,6 @@ # Reverted * [SPARK-25908](https://issues.apache.org/jira/browse/SPARK-25908) - Removal of `monotonicall_increasing_id`, `toDegree`, `toRadians`, `approxCountDistinct`, `unionAll` -* [SPARK-25862](https://issues.apache.org/jira/browse/SPARK-25862) - Removal of `unboundedPreceding`, `unboundedFollowing`, `currentRow` \ No newline at end of file +* [SPARK-25862](https://issues.apache.org/jira/browse/SPARK-25862) - Removal of `unboundedPreceding`, `unboundedFollowing`, `currentRow` +* [SPARK-26127](https://issues.apache.org/jira/browse/SPARK-26127) - Removal of deprecated setters from tree regression and classification models +* 
[SPARK-25867](https://issues.apache.org/jira/browse/SPARK-25867) - Removal of KMeans computeCost \ No newline at end of file diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 050c65efbe164..842730e7deb13 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -114,13 +114,6 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"), - // [SPARK-25862][SQL] Remove rangeBetween APIs introduced in SPARK-21608 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.unboundedFollowing"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.unboundedPreceding"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.currentRow"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.Window.rangeBetween"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.rangeBetween"), - // [SPARK-23781][CORE] Merge token renewer functionality into HadoopDelegationTokenManager ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.nextCredentialRenewalTime"),
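
For context on the SPARK-26127 revert recorded in FORK.md above, the sketch below shows the kind of call site that the restored deprecated setters keep compiling on this fork. It is an illustration only, not part of any patch in this series; the `spark` session, the toy data, and the column names are assumptions.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.DecisionTreeRegressor

// Illustrative toy data; an existing SparkSession named `spark` is assumed.
val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1)),
  (0.0, Vectors.dense(2.0, 1.0)),
  (1.5, Vectors.dense(1.0, 0.5))
)).toDF("label", "features")

val model = new DecisionTreeRegressor()
  .setMaxDepth(5)   // estimator-side setter, unaffected by the revert
  .fit(training)

// With SPARK-26127 reverted, the deprecated model-side setters remain available:
// they compile with a deprecation warning instead of failing to resolve.
model.setCheckpointInterval(20)
model.setMaxBins(64)
```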