From 804949d519e2caa293a409d84b4e6190c1105444 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Sun, 8 Feb 2015 14:55:07 -0800
Subject: [PATCH 001/817] [SQL] Set sessionState in QueryExecution.
This PR sets the SessionState in HiveContext's QueryExecution, so that SessionState.get returns the SessionState every time.
Author: Yin Huai
Closes #4445 from yhuai/setSessionState and squashes the following commits:
769c9f1 [Yin Huai] Remove unused import.
439f329 [Yin Huai] Try again.
427a0c9 [Yin Huai] Set SessionState everytime when we create a QueryExecution in HiveContext.
a3b7793 [Yin Huai] Set sessionState when dealing with CreateTableAsSelect.
---
.../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index ad37b7d0e6f59..2c00659496972 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -424,6 +424,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/** Extends QueryExecution with hive specific features. */
protected[sql] class QueryExecution(logicalPlan: LogicalPlan)
extends super.QueryExecution(logicalPlan) {
+ // Like what we do in runHive, makes sure the session represented by the
+ // `sessionState` field is activated.
+ if (SessionState.get() != sessionState) {
+ SessionState.start(sessionState)
+ }
/**
* Returns the result as a hive compatible sequence of strings. For native commands, the
From 5c299c58fb9a5434a40be82150d4725bba805adf Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sun, 8 Feb 2015 16:26:20 -0800
Subject: [PATCH 002/817] [SPARK-5598][MLLIB] model save/load for ALS
following #4233. jkbradley
Author: Xiangrui Meng
Closes #4422 from mengxr/SPARK-5598 and squashes the following commits:
a059394 [Xiangrui Meng] SaveLoad not extending Loader
14b7ea6 [Xiangrui Meng] address comments
f487cb2 [Xiangrui Meng] add unit tests
62fc43c [Xiangrui Meng] implement save/load for MFM
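A minimal usage sketch of the new save/load round trip (assuming an existing SparkContext `sc` and a trained `model`; the path shown here is hypothetical):
```
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel

// Persist the model: metadata is written as JSON, user/product factors as Parquet.
model.save(sc, "hdfs:///tmp/alsModel")
// Reload it later; rank, userFeatures and productFeatures are restored.
val sameModel = MatrixFactorizationModel.load(sc, "hdfs:///tmp/alsModel")
```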
---
.../spark/mllib/recommendation/ALS.scala | 2 +-
.../MatrixFactorizationModel.scala | 82 ++++++++++++++++++-
.../MatrixFactorizationModelSuite.scala | 19 +++++
3 files changed, 100 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 4bb28d1b1e071..caacab943030b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.recommendation
import org.apache.spark.Logging
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.recommendation.{ALS => NewALS}
import org.apache.spark.rdd.RDD
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index ed2f8b41bcae5..9ff06ac362a31 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,13 +17,17 @@
package org.apache.spark.mllib.recommendation
+import java.io.IOException
import java.lang.{Integer => JavaInteger}
+import org.apache.hadoop.fs.Path
import org.jblas.DoubleMatrix
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.storage.StorageLevel
/**
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
- val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+ val productFeatures: RDD[(Int, Array[Double])])
+ extends Saveable with Serializable with Logging {
require(rank > 0)
validateFeatures("User", userFeatures)
@@ -125,6 +130,12 @@ class MatrixFactorizationModel(
recommend(productFeatures.lookup(product).head, userFeatures, num)
.map(t => Rating(t._1, product, t._2))
+ protected override val formatVersion: String = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
+ }
+
private def recommend(
recommendToFeatures: Array[Double],
recommendableFeatures: RDD[(Int, Array[Double])],
@@ -136,3 +147,70 @@ class MatrixFactorizationModel(
scored.top(num)(Ordering.by(_._2))
}
}
+
+object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
+
+ import org.apache.spark.mllib.util.Loader._
+
+ override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+ val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, formatVersion) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ SaveLoadV1_0.load(sc, path)
+ case _ =>
+ throw new IOException("MatrixFactorizationModel.load did not recognize model with" +
+ s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private[recommendation]
+ object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private[recommendation]
+ val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
+
+ /**
+ * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and
+ * product features are saved under `data/products`.
+ */
+ def save(model: MatrixFactorizationModel, path: String): Unit = {
+ val sc = model.userFeatures.sparkContext
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits.createDataFrame
+ val metadata = (thisClassName, thisFormatVersion, model.rank)
+ val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
+ metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+ model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
+ model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+ val sqlContext = new SQLContext(sc)
+ val (className, formatVersion, metadata) = loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+ val rank = metadata.select("rank").first().getInt(0)
+ val userFeatures = sqlContext.parquetFile(userPath(path))
+ .map { case Row(id: Int, features: Seq[Double]) =>
+ (id, features.toArray)
+ }
+ val productFeatures = sqlContext.parquetFile(productPath(path))
+ .map { case Row(id: Int, features: Seq[Double]) =>
+ (id, features.toArray)
+ }
+ new MatrixFactorizationModel(rank, userFeatures, productFeatures)
+ }
+
+ private def userPath(path: String): String = {
+ new Path(dataPath(path), "user").toUri.toString
+ }
+
+ private def productPath(path: String): String = {
+ new Path(dataPath(path), "product").toUri.toString
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
index b9caecc904a23..9801e87576744 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
@@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
}
}
+
+ test("save/load") {
+ val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
+ features.mapValues(_.toSeq).collect().toSet
+ }
+ try {
+ model.save(sc, path)
+ val newModel = MatrixFactorizationModel.load(sc, path)
+ assert(newModel.rank === rank)
+ assert(collect(newModel.userFeatures) === collect(userFeatures))
+ assert(collect(newModel.productFeatures) === collect(prodFeatures))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
From 56aff4bd6c7c9d18f4f962025708f20a4a82dcf0 Mon Sep 17 00:00:00 2001
From: Sam Halliday
Date: Sun, 8 Feb 2015 16:34:26 -0800
Subject: [PATCH 003/817] SPARK-5665 [DOCS] Update netlib-java documentation
I am the author of netlib-java and I found this documentation to be out of date. Some main points:
1. Breeze has not depended on jBLAS for some time
2. netlib-java provides a pure JVM implementation as the fallback (the original docs did not appear to be aware of this, claiming that gfortran was necessary)
3. The licensing issue is not just about LGPL: optimised natives have proprietary licenses. Building with the LGPL flag turned on really doesn't help you get past this.
4. I really think it's best to direct people to my detailed setup guide instead of trying to compress it into one sentence. It is different for each architecture, each OS, and for each backend.
I hope this helps to clear things up :smile:
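For readers who want the system-optimised path described in the updated docs, a minimal sbt sketch of adding the native proxies as an application dependency (illustrative only; `pomOnly()` is used on the assumption that the `all` artifact is POM-only, and the netlib-java documentation should be consulted for platform-specific installation):
```
// build.sbt (sketch): pull in netlib-java's native proxies
libraryDependencies += "com.github.fommil.netlib" % "all" % "1.1.2" pomOnly()
```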
Author: Sam Halliday
Author: Sam Halliday
Closes #4448 from fommil/patch-1 and squashes the following commits:
18cda11 [Sam Halliday] remove link to skillsmatters at request of @mengxr
a35e4a9 [Sam Halliday] reword netlib-java/breeze docs
---
docs/mllib-guide.md | 41 ++++++++++++++++++++++++-----------------
1 file changed, 24 insertions(+), 17 deletions(-)
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index 7779fbc9c49e4..3d32d03e35c62 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -56,25 +56,32 @@ See the **[spark.ml programming guide](ml-guide.html)** for more information on
# Dependencies
-MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/),
-which depends on [netlib-java](https://github.com/fommil/netlib-java),
-and [jblas](https://github.com/mikiobraun/jblas).
-`netlib-java` and `jblas` depend on native Fortran routines.
-You need to install the
+MLlib uses the linear algebra package
+[Breeze](http://www.scalanlp.org/), which depends on
+[netlib-java](https://github.com/fommil/netlib-java) for optimised
+numerical processing. If natives are not available at runtime, you
+will see a warning message and a pure JVM implementation will be used
+instead.
+
+To learn more about the benefits and background of system optimised
+natives, you may wish to watch Sam Halliday's ScalaX talk on
+[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)).
+
+Due to licensing issues with runtime proprietary binaries, we do not
+include `netlib-java`'s native proxies by default. To configure
+`netlib-java` / Breeze to use system optimised binaries, include
+`com.github.fommil.netlib:all:1.1.2` (or build Spark with
+`-Pnetlib-lgpl`) as a dependency of your project and read the
+[netlib-java](https://github.com/fommil/netlib-java) documentation for
+your platform's additional installation instructions.
+
+MLlib also uses [jblas](https://github.com/mikiobraun/jblas) which
+will require you to install the
[gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries)
if it is not already present on your nodes.
-MLlib will throw a linking error if it cannot detect these libraries automatically.
-Due to license issues, we do not include `netlib-java`'s native libraries in MLlib's
-dependency set under default settings.
-If no native library is available at runtime, you will see a warning message.
-To use native libraries from `netlib-java`, please build Spark with `-Pnetlib-lgpl` or
-include `com.github.fommil.netlib:all:1.1.2` as a dependency of your project.
-If you want to use optimized BLAS/LAPACK libraries such as
-[OpenBLAS](http://www.openblas.net/), please link its shared libraries to
-`/usr/lib/libblas.so.3` and `/usr/lib/liblapack.so.3`, respectively.
-BLAS/LAPACK libraries on worker nodes should be built without multithreading.
-
-To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer.
+
+To use MLlib in Python, you will need [NumPy](http://www.numpy.org)
+version 1.4 or newer.
---
From a052ed42501fee3641348337505b6176426653c4 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 8 Feb 2015 18:56:51 -0800
Subject: [PATCH 004/817] [SPARK-5643][SQL] Add a show method to print the
content of a DataFrame in tabular format.
An example:
```
year month AVG('Adj Close) MAX('Adj Close)
1980 12 0.503218 0.595103
1981 01 0.523289 0.570307
1982 02 0.436504 0.475256
1983 03 0.410516 0.442194
1984 04 0.450090 0.483521
```
Author: Reynold Xin
Closes #4416 from rxin/SPARK-5643 and squashes the following commits:
d0e0d6e [Reynold Xin] [SQL] Minor update to data source and statistics documentation.
269da83 [Reynold Xin] Updated isLocal comment.
2cf3c27 [Reynold Xin] Moved logic into optimizer.
1a04d8b [Reynold Xin] [SPARK-5643][SQL] Add a show method to print the content of a DataFrame in columnar format.
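A minimal usage sketch of the new method (assuming an existing SQLContext `sqlContext`; the query and table name are hypothetical):
```
// Any DataFrame will do; here one is produced from a SQL query.
val df = sqlContext.sql("SELECT year, month FROM stocks")
// show() prints the first 20 rows in the tabular form above,
// truncating cells longer than 20 characters to their first 17 characters plus "...".
df.show()
```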
---
.../sql/catalyst/optimizer/Optimizer.scala | 18 +++++-
.../catalyst/plans/logical/LogicalPlan.scala | 7 ++-
.../ConvertToLocalRelationSuite.scala | 57 +++++++++++++++++++
.../org/apache/spark/sql/DataFrame.scala | 21 ++++++-
.../org/apache/spark/sql/DataFrameImpl.scala | 41 +++++++++++--
.../apache/spark/sql/IncomputableColumn.scala | 6 +-
.../spark/sql/execution/basicOperators.scala | 7 +--
.../apache/spark/sql/sources/interfaces.scala | 15 +++--
8 files changed, 151 insertions(+), 21 deletions(-)
create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8c8f2896eb99b..3bc48c95c5653 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -50,7 +50,9 @@ object DefaultOptimizer extends Optimizer {
CombineFilters,
PushPredicateThroughProject,
PushPredicateThroughJoin,
- ColumnPruning) :: Nil
+ ColumnPruning) ::
+ Batch("LocalRelation", FixedPoint(100),
+ ConvertToLocalRelation) :: Nil
}
/**
@@ -610,3 +612,17 @@ object DecimalAggregates extends Rule[LogicalPlan] {
DecimalType(prec + 4, scale + 4))
}
}
+
+/**
+ * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to
+ * another LocalRelation.
+ *
+ * This is relatively simple as it currently handles only a single case: Project.
+ */
+object ConvertToLocalRelation extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Project(projectList, LocalRelation(output, data)) =>
+ val projection = new InterpretedProjection(projectList, output)
+ LocalRelation(projectList.map(_.toAttribute), data.map(projection))
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8d30528328946..7cf4b81274906 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -29,12 +29,15 @@ import org.apache.spark.sql.catalyst.trees
/**
* Estimates of various statistics. The default estimation logic simply lazily multiplies the
* corresponding statistic produced by the children. To override this behavior, override
- * `statistics` and assign it an overriden version of `Statistics`.
+ * `statistics` and assign it an overridden version of `Statistics`.
*
- * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the
+ * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
* performance of the implementations. The reason is that estimations might get triggered in
* performance-critical processes, such as query plan planning.
*
+ * Note that we are using a BigInt here since it is easy to overflow a 64-bit integer in
+ * cardinality estimation (e.g. cartesian joins).
+ *
* @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
* defaults to the product of children's `sizeInBytes`.
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
new file mode 100644
index 0000000000000..cf42d43823399
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+
+class ConvertToLocalRelationSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("LocalRelation", FixedPoint(100),
+ ConvertToLocalRelation) :: Nil
+ }
+
+ test("Project on LocalRelation should be turned into a single LocalRelation") {
+ val testRelation = LocalRelation(
+ LocalRelation('a.int, 'b.int).output,
+ Row(1, 2) ::
+ Row(4, 5) :: Nil)
+
+ val correctAnswer = LocalRelation(
+ LocalRelation('a1.int, 'b1.int).output,
+ Row(1, 3) ::
+ Row(4, 6) :: Nil)
+
+ val projectOnLocal = testRelation.select(
+ UnresolvedAttribute("a").as("a1"),
+ (UnresolvedAttribute("b") + 1).as("b1"))
+
+ val optimized = Optimize(projectOnLocal.analyze)
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 8ad6526f872e5..17ea3cde8e50e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -102,7 +102,7 @@ trait DataFrame extends RDDApi[Row] {
* }}}
*/
@scala.annotation.varargs
- def toDataFrame(colName: String, colNames: String*): DataFrame
+ def toDataFrame(colNames: String*): DataFrame
/** Returns the schema of this [[DataFrame]]. */
def schema: StructType
@@ -116,6 +116,25 @@ trait DataFrame extends RDDApi[Row] {
/** Prints the schema to the console in a nice tree format. */
def printSchema(): Unit
+ /**
+ * Returns true if the `collect` and `take` methods can be run locally
+ * (without any Spark executors).
+ */
+ def isLocal: Boolean
+
+ /**
+ * Displays the [[DataFrame]] in a tabular form. For example:
+ * {{{
+ * year month AVG('Adj Close) MAX('Adj Close)
+ * 1980 12 0.503218 0.595103
+ * 1981 01 0.523289 0.570307
+ * 1982 02 0.436504 0.475256
+ * 1983 03 0.410516 0.442194
+ * 1984 04 0.450090 0.483521
+ * }}}
+ */
+ def show(): Unit
+
/**
* Cartesian join with another [[DataFrame]].
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 789bcf6184b3e..fa05a5dcac6bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -90,14 +90,13 @@ private[sql] class DataFrameImpl protected[sql](
}
}
- override def toDataFrame(colName: String, colNames: String*): DataFrame = {
- val newNames = colName +: colNames
- require(schema.size == newNames.size,
+ override def toDataFrame(colNames: String*): DataFrame = {
+ require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
"Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
- "New column names: " + newNames.mkString(", "))
+ "New column names: " + colNames.mkString(", "))
- val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
+ val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) =>
apply(oldName).as(newName)
}
select(newCols :_*)
@@ -113,6 +112,38 @@ private[sql] class DataFrameImpl protected[sql](
override def printSchema(): Unit = println(schema.treeString)
+ override def isLocal: Boolean = {
+ logicalPlan.isInstanceOf[LocalRelation]
+ }
+
+ override def show(): Unit = {
+ val data = take(20)
+ val numCols = schema.fieldNames.length
+
+ // For cells that are beyond 20 characters, replace it with the first 17 and "..."
+ val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ row.toSeq.map { cell =>
+ val str = if (cell == null) "null" else cell.toString
+ if (str.length > 20) str.substring(0, 17) + "..." else str
+ } : Seq[String]
+ }
+
+ // Compute the width of each column
+ val colWidths = Array.fill(numCols)(0)
+ for (row <- rows) {
+ for ((cell, i) <- row.zipWithIndex) {
+ colWidths(i) = math.max(colWidths(i), cell.length)
+ }
+ }
+
+ // Pad the cells and print them
+ println(rows.map { row =>
+ row.zipWithIndex.map { case (cell, i) =>
+ String.format(s"%-${colWidths(i)}s", cell)
+ }.mkString(" ")
+ }.mkString("\n"))
+ }
+
override def join(right: DataFrame): DataFrame = {
Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 6043fb4dee01d..782f6e28eebb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -48,7 +48,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
protected[sql] override def logicalPlan: LogicalPlan = err()
- override def toDataFrame(colName: String, colNames: String*): DataFrame = err()
+ override def toDataFrame(colNames: String*): DataFrame = err()
override def schema: StructType = err()
@@ -58,6 +58,10 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def printSchema(): Unit = err()
+ override def show(): Unit = err()
+
+ override def isLocal: Boolean = false
+
override def join(right: DataFrame): DataFrame = err()
override def join(right: DataFrame, joinExprs: Column): DataFrame = err()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 66aed5d5113d1..4dc506c21ab9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,9 +17,6 @@
package org.apache.spark.sql.execution
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.runtime.universe.TypeTag
-
import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
@@ -40,7 +37,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
@transient lazy val buildProjection = newMutableProjection(projectList, child.output)
- def execute() = child.execute().mapPartitions { iter =>
+ override def execute() = child.execute().mapPartitions { iter =>
val resuableProjection = buildProjection()
iter.map(resuableProjection)
}
@@ -55,7 +52,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
@transient lazy val conditionEvaluator = newPredicate(condition, child.output)
- def execute() = child.execute().mapPartitions { iter =>
+ override def execute() = child.execute().mapPartitions { iter =>
iter.filter(conditionEvaluator)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index a640ba57e0885..5eecc303ef72b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -87,13 +87,13 @@ trait CreatableRelationProvider {
/**
* ::DeveloperApi::
- * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must
- * be able to produce the schema of their data in the form of a [[StructType]] Concrete
+ * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must
+ * be able to produce the schema of their data in the form of a [[StructType]]. Concrete
* implementation should inherit from one of the descendant `Scan` classes, which define various
* abstract methods for execution.
*
* BaseRelations must also define a equality function that only returns true when the two
- * instances will return the same data. This equality function is used when determining when
+ * instances will return the same data. This equality function is used when determining when
* it is safe to substitute cached results for a given relation.
*/
@DeveloperApi
@@ -102,13 +102,16 @@ abstract class BaseRelation {
def schema: StructType
/**
- * Returns an estimated size of this relation in bytes. This information is used by the planner
+ * Returns an estimated size of this relation in bytes. This information is used by the planner
* to decided when it is safe to broadcast a relation and can be overridden by sources that
* know the size ahead of time. By default, the system will assume that tables are too
- * large to broadcast. This method will be called multiple times during query planning
+ * large to broadcast. This method will be called multiple times during query planning
* and thus should not perform expensive operations for each invocation.
+ *
+ * Note that it is always better to overestimate size than underestimate, because underestimation
+ * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table).
*/
- def sizeInBytes = sqlContext.conf.defaultSizeInBytes
+ def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
}
/**
From c17161189d57f2e3a8d3550ea59a68edf487c8b7 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sun, 8 Feb 2015 21:07:36 -0800
Subject: [PATCH 005/817] [SPARK-5660][MLLIB] Make Matrix apply public
This is #4447 with `override`.
Closes #4447
Author: Joseph K. Bradley
Author: Xiangrui Meng
Closes #4462 from mengxr/SPARK-5660 and squashes the following commits:
f82c8d6 [Xiangrui Meng] add override to matrix.apply
91cedde [Joseph K. Bradley] made matrix apply public
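A small sketch of what the newly public accessor allows from user code (values chosen for illustration):
```
import org.apache.spark.mllib.linalg.Matrices

// 2x2 dense matrix in column-major order: [[1.0, 3.0], [2.0, 4.0]]
val m = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))
// apply(i, j) is now public, so elements can be read directly.
val topRight = m(0, 1)  // 3.0
```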
---
.../main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 84f8ac2e0d9cd..c8a97b8c53d9b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -50,7 +50,7 @@ sealed trait Matrix extends Serializable {
private[mllib] def toBreeze: BM[Double]
/** Gets the (i, j)-th element. */
- private[mllib] def apply(i: Int, j: Int): Double
+ def apply(i: Int, j: Int): Double
/** Return the index for the (i, j)-th element in the backing array. */
private[mllib] def index(i: Int, j: Int): Int
@@ -163,7 +163,7 @@ class DenseMatrix(
private[mllib] def apply(i: Int): Double = values(i)
- private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j))
+ override def apply(i: Int, j: Int): Double = values(index(i, j))
private[mllib] def index(i: Int, j: Int): Int = {
if (!isTransposed) i + numRows * j else j + numCols * i
@@ -398,7 +398,7 @@ class SparseMatrix(
}
}
- private[mllib] def apply(i: Int, j: Int): Double = {
+ override def apply(i: Int, j: Int): Double = {
val ind = index(i, j)
if (ind < 0) 0.0 else values(ind)
}
From 4396dfb37f433ef186e3e0a09db9906986ec940b Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Sun, 8 Feb 2015 21:08:50 -0800
Subject: [PATCH 006/817] SPARK-4405 [MLLIB] Matrices.* construction methods
should check for rows x cols overflow
Check that size of dense matrix array is not beyond Int.MaxValue in Matrices.* methods. jkbradley this should be an easy one. Review and/or merge as you see fit.
Author: Sean Owen
Closes #4461 from srowen/SPARK-4405 and squashes the following commits:
c67574e [Sean Owen] Check that size of dense matrix array is not beyond Int.MaxValue in Matrices.* methods
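An illustrative sketch of the new check (sizes are made up):
```
import org.apache.spark.mllib.linalg.DenseMatrix

// Fine: 1000 x 1000 = 1e6 elements fits in an Int-indexed array.
val ok = DenseMatrix.zeros(1000, 1000)
// Would now fail fast with an IllegalArgumentException from require(),
// since 100000 x 100000 = 1e10 exceeds Int.MaxValue elements:
// val tooBig = DenseMatrix.zeros(100000, 100000)
```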
---
.../org/apache/spark/mllib/linalg/Matrices.scala | 14 ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index c8a97b8c53d9b..89b38679b7494 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -256,8 +256,11 @@ object DenseMatrix {
* @param numCols number of columns of the matrix
* @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros
*/
- def zeros(numRows: Int, numCols: Int): DenseMatrix =
+ def zeros(numRows: Int, numCols: Int): DenseMatrix = {
+ require(numRows.toLong * numCols <= Int.MaxValue,
+ s"$numRows x $numCols dense matrix is too large to allocate")
new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols))
+ }
/**
* Generate a `DenseMatrix` consisting of ones.
@@ -265,8 +268,11 @@ object DenseMatrix {
* @param numCols number of columns of the matrix
* @return `DenseMatrix` with size `numRows` x `numCols` and values of ones
*/
- def ones(numRows: Int, numCols: Int): DenseMatrix =
+ def ones(numRows: Int, numCols: Int): DenseMatrix = {
+ require(numRows.toLong * numCols <= Int.MaxValue,
+ s"$numRows x $numCols dense matrix is too large to allocate")
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0))
+ }
/**
* Generate an Identity Matrix in `DenseMatrix` format.
@@ -291,6 +297,8 @@ object DenseMatrix {
* @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
*/
def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = {
+ require(numRows.toLong * numCols <= Int.MaxValue,
+ s"$numRows x $numCols dense matrix is too large to allocate")
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble()))
}
@@ -302,6 +310,8 @@ object DenseMatrix {
* @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
*/
def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = {
+ require(numRows.toLong * numCols <= Int.MaxValue,
+ s"$numRows x $numCols dense matrix is too large to allocate")
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian()))
}
From 4575c5643a82818bf64f9648314bdc2fdc12febb Mon Sep 17 00:00:00 2001
From: Hung Lin
Date: Sun, 8 Feb 2015 22:36:42 -0800
Subject: [PATCH 007/817] [SPARK-5472][SQL] Fix Scala code style
Fix Scala code style.
Author: Hung Lin
Closes #4464 from hunglin/SPARK-5472 and squashes the following commits:
ef7a3b3 [Hung Lin] SPARK-5472: fix scala style
---
.../org/apache/spark/sql/jdbc/JDBCRDD.scala | 42 +++++++++----------
.../apache/spark/sql/jdbc/JDBCRelation.scala | 35 +++++++++-------
2 files changed, 41 insertions(+), 36 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index a2f94675fb5a3..0bec32cca1325 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -17,13 +17,10 @@
package org.apache.spark.sql.jdbc
-import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, ResultSetMetaData, SQLException}
-import scala.collection.mutable.ArrayBuffer
+import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException}
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.NextIterator
-import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources._
@@ -100,7 +97,7 @@ private[sql] object JDBCRDD extends Logging {
try {
val rsmd = rs.getMetaData
val ncols = rsmd.getColumnCount
- var fields = new Array[StructField](ncols);
+ val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
val columnName = rsmd.getColumnName(i + 1)
@@ -176,23 +173,27 @@ private[sql] object JDBCRDD extends Logging {
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
- def scanTable(sc: SparkContext,
- schema: StructType,
- driver: String,
- url: String,
- fqTable: String,
- requiredColumns: Array[String],
- filters: Array[Filter],
- parts: Array[Partition]): RDD[Row] = {
+ def scanTable(
+ sc: SparkContext,
+ schema: StructType,
+ driver: String,
+ url: String,
+ fqTable: String,
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ parts: Array[Partition]): RDD[Row] = {
+
val prunedSchema = pruneSchema(schema, requiredColumns)
- return new JDBCRDD(sc,
- getConnector(driver, url),
- prunedSchema,
- fqTable,
- requiredColumns,
- filters,
- parts)
+ return new
+ JDBCRDD(
+ sc,
+ getConnector(driver, url),
+ prunedSchema,
+ fqTable,
+ requiredColumns,
+ filters,
+ parts)
}
}
@@ -412,6 +413,5 @@ private[sql] class JDBCRDD(
gotNext = false
nextValue
}
-
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index e09125e406ba2..66ad38eb7c45b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -96,7 +96,8 @@ private[sql] class DefaultSource extends RelationProvider {
if (driver != null) Class.forName(driver)
- if ( partitionColumn != null
+ if (
+ partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
sys.error("Partitioning incompletely specified")
}
@@ -104,30 +105,34 @@ private[sql] class DefaultSource extends RelationProvider {
val partitionInfo = if (partitionColumn == null) {
null
} else {
- JDBCPartitioningInfo(partitionColumn,
- lowerBound.toLong, upperBound.toLong,
- numPartitions.toInt)
+ JDBCPartitioningInfo(
+ partitionColumn,
+ lowerBound.toLong,
+ upperBound.toLong,
+ numPartitions.toInt)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
JDBCRelation(url, table, parts)(sqlContext)
}
}
-private[sql] case class JDBCRelation(url: String,
- table: String,
- parts: Array[Partition])(
- @transient val sqlContext: SQLContext)
- extends PrunedFilteredScan {
+private[sql] case class JDBCRelation(
+ url: String,
+ table: String,
+ parts: Array[Partition])(@transient val sqlContext: SQLContext) extends PrunedFilteredScan {
override val schema = JDBCRDD.resolveTable(url, table)
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
- JDBCRDD.scanTable(sqlContext.sparkContext,
- schema,
- driver, url,
- table,
- requiredColumns, filters,
- parts)
+ JDBCRDD.scanTable(
+ sqlContext.sparkContext,
+ schema,
+ driver,
+ url,
+ table,
+ requiredColumns,
+ filters,
+ parts)
}
}
From 855d12ac0a9cdade4cd2cc64c4e7209478be6690 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sun, 8 Feb 2015 23:40:36 -0800
Subject: [PATCH 008/817] [SPARK-5539][MLLIB] LDA guide
This is the LDA user guide from jkbradley with Java and Scala code examples.
Author: Xiangrui Meng
Author: Joseph K. Bradley
Closes #4465 from mengxr/lda-guide and squashes the following commits:
6dcb7d1 [Xiangrui Meng] update java example in the user guide
76169ff [Xiangrui Meng] update java example
36c3ae2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into lda-guide
c2a1efe [Joseph K. Bradley] Added LDA programming guide, plus Java example (which is in the guide and probably should be removed).
---
data/mllib/sample_lda_data.txt | 12 ++
docs/mllib-clustering.md | 129 +++++++++++++++++-
.../spark/examples/mllib/JavaLDAExample.java | 75 ++++++++++
3 files changed, 215 insertions(+), 1 deletion(-)
create mode 100644 data/mllib/sample_lda_data.txt
create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
diff --git a/data/mllib/sample_lda_data.txt b/data/mllib/sample_lda_data.txt
new file mode 100644
index 0000000000000..2e76702ca9d67
--- /dev/null
+++ b/data/mllib/sample_lda_data.txt
@@ -0,0 +1,12 @@
+1 2 6 0 2 3 1 1 0 0 3
+1 3 0 1 3 0 0 2 0 0 1
+1 4 1 0 0 4 9 0 1 2 0
+2 1 0 3 0 0 5 0 2 3 9
+3 1 1 9 3 0 2 0 0 1 3
+4 2 0 3 4 5 1 1 1 4 0
+2 1 0 3 0 0 5 0 2 2 9
+1 1 1 9 2 1 2 0 0 1 3
+4 4 0 3 4 2 1 3 0 0 0
+2 8 2 0 3 0 2 0 2 7 2
+1 1 1 9 0 2 2 0 0 3 3
+4 1 0 0 4 5 1 3 0 1 0
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 1e9ef345b7435..99ed6b60e3f00 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -55,7 +55,7 @@ has the following parameters:
Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm:
-* accepts a [Graph](https://spark.apache.org/docs/0.9.2/api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points.
+* accepts a [Graph](api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points.
* calculates the principal eigenvalue and eigenvector
* Clusters each of the input points according to their principal eigenvector component value
@@ -71,6 +71,35 @@ Example outputs for a dataset inspired by the paper - but with five clusters ins
+### Latent Dirichlet Allocation (LDA)
+
+[Latent Dirichlet Allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation)
+is a topic model which infers topics from a collection of text documents.
+LDA can be thought of as a clustering algorithm as follows:
+
+* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset.
+* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts.
+* Rather than estimating a clustering using a traditional distance, LDA uses a function based
+ on a statistical model of how text documents are generated.
+
+LDA takes in a collection of documents as vectors of word counts.
+It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm)
+on the likelihood function. After fitting on the documents, LDA provides:
+
+* Topics: Inferred topics, each of which is a probability distribution over terms (words).
+* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics.
+
+LDA takes the following parameters:
+
+* `k`: Number of topics (i.e., cluster centers)
+* `maxIterations`: Limit on the number of iterations of EM used for learning
+* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions.
+* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions.
+* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery.
+
+*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet
+support prediction on new documents, and it does not have a Python API. These will be added in the future.
+
### Examples
#### k-means
@@ -293,6 +322,104 @@ for i in range(2):
+#### Latent Dirichlet Allocation (LDA) Example
+
+In the following example, we load word count vectors representing a corpus of documents.
+We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA)
+to infer three topics from the documents. The number of desired clusters is passed
+to the algorithm. We then output the topics, represented as probability distributions over words.
+
+
+
+
+{% highlight scala %}
+import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.linalg.Vectors
+
+// Load and parse the data
+val data = sc.textFile("data/mllib/sample_lda_data.txt")
+val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
+// Index documents with unique IDs
+val corpus = parsedData.zipWithIndex.map(_.swap).cache()
+
+// Cluster the documents into three topics using LDA
+val ldaModel = new LDA().setK(3).run(corpus)
+
+// Output topics. Each is a distribution over words (matching word count vectors)
+println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):")
+val topics = ldaModel.topicsMatrix
+for (topic <- Range(0, 3)) {
+ print("Topic " + topic + ":")
+ for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); }
+ println()
+}
+{% endhighlight %}
+
+
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.DistributedLDAModel;
+import org.apache.spark.mllib.clustering.LDA;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.SparkConf;
+
+public class JavaLDAExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("LDA Example");
+ JavaSparkContext sc = new JavaSparkContext(conf);
+
+ // Load and parse the data
+ String path = "data/mllib/sample_lda_data.txt";
+ JavaRDD<String> data = sc.textFile(path);
+ JavaRDD<Vector> parsedData = data.map(
+ new Function<String, Vector>() {
+ public Vector call(String s) {
+ String[] sarray = s.trim().split(" ");
+ double[] values = new double[sarray.length];
+ for (int i = 0; i < sarray.length; i++)
+ values[i] = Double.parseDouble(sarray[i]);
+ return Vectors.dense(values);
+ }
+ }
+ );
+ // Index documents with unique IDs
+ JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(
+ new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
+ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
+ return doc_id.swap();
+ }
+ }
+ ));
+ corpus.cache();
+
+ // Cluster the documents into three topics using LDA
+ DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
+
+ // Output topics. Each is a distribution over words (matching word count vectors)
+ System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
+ + " words):");
+ Matrix topics = ldaModel.topicsMatrix();
+ for (int topic = 0; topic < 3; topic++) {
+ System.out.print("Topic " + topic + ":");
+ for (int word = 0; word < ldaModel.vocabSize(); word++) {
+ System.out.print(" " + topics.apply(word, topic));
+ }
+ System.out.println();
+ }
+ }
+}
+{% endhighlight %}
+
+
+
+
+
In order to run the above application, follow the instructions
provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
section of the Spark
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
new file mode 100644
index 0000000000000..f394ff2084463
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.DistributedLDAModel;
+import org.apache.spark.mllib.clustering.LDA;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.SparkConf;
+
+public class JavaLDAExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("LDA Example");
+ JavaSparkContext sc = new JavaSparkContext(conf);
+
+ // Load and parse the data
+ String path = "data/mllib/sample_lda_data.txt";
+ JavaRDD<String> data = sc.textFile(path);
+ JavaRDD<Vector> parsedData = data.map(
+ new Function<String, Vector>() {
+ public Vector call(String s) {
+ String[] sarray = s.trim().split(" ");
+ double[] values = new double[sarray.length];
+ for (int i = 0; i < sarray.length; i++)
+ values[i] = Double.parseDouble(sarray[i]);
+ return Vectors.dense(values);
+ }
+ }
+ );
+ // Index documents with unique IDs
+ JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(
+ new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
+ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
+ return doc_id.swap();
+ }
+ }
+ ));
+ corpus.cache();
+
+ // Cluster the documents into three topics using LDA
+ DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
+
+ // Output topics. Each is a distribution over words (matching word count vectors)
+ System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
+ + " words):");
+ Matrix topics = ldaModel.topicsMatrix();
+ for (int topic = 0; topic < 3; topic++) {
+ System.out.print("Topic " + topic + ":");
+ for (int word = 0; word < ldaModel.vocabSize(); word++) {
+ System.out.print(" " + topics.apply(word, topic));
+ }
+ System.out.println();
+ }
+ }
+}
From 4dfe180fc893bee1146161f8b2a6efd4d6d2bb8c Mon Sep 17 00:00:00 2001
From: Nicholas Chammas
Date: Mon, 9 Feb 2015 09:44:53 +0000
Subject: [PATCH 009/817] [SPARK-5473] [EC2] Expose SSH failures after status
checks pass
If there is some fatal problem with launching a cluster, `spark-ec2` just hangs without giving the user useful feedback on what the problem is.
This PR exposes the output of the SSH calls to the user if the SSH test fails during cluster launch for any reason even though the instance status checks are all green. It also replaces the growing trail of dots shown while waiting with a fixed three dots.
For example:
```
$ ./ec2/spark-ec2 -k key -i /incorrect/path/identity.pem --instance-type m3.medium --slaves 1 --zone us-east-1c launch "spark-test"
Setting up security groups...
Searching for existing cluster spark-test...
Spark AMI: ami-35b1885c
Launching instances...
Launched 1 slaves in us-east-1c, regid = r-7dadd096
Launched master in us-east-1c, regid = r-fcadd017
Waiting for cluster to enter 'ssh-ready' state...
Warning: SSH connection error. (This could be temporary.)
Host: 127.0.0.1
SSH return code: 255
SSH output: Warning: Identity file /incorrect/path/identity.pem not accessible: No such file or directory.
Warning: Permanently added '127.0.0.1' (RSA) to the list of known hosts.
Permission denied (publickey).
```
This should give users enough information when some unrecoverable error occurs during launch so they can know to abort the launch. This will help avoid situations like the ones reported [here on Stack Overflow](http://stackoverflow.com/q/28002443/) and [here on the user list](http://mail-archives.apache.org/mod_mbox/spark-user/201501.mbox/%3C1422323829398-21381.postn3.nabble.com%3E), where the users couldn't tell what the problem was because it was being hidden by `spark-ec2`.
This is a usability improvement that should be backported to 1.2.
Resolves [SPARK-5473](https://issues.apache.org/jira/browse/SPARK-5473).
Author: Nicholas Chammas
Closes #4262 from nchammas/expose-ssh-failure and squashes the following commits:
8bda6ed [Nicholas Chammas] default to print SSH output
2b92534 [Nicholas Chammas] show SSH output after status check pass
---
ec2/spark_ec2.py | 36 ++++++++++++++++++++++++------------
1 file changed, 24 insertions(+), 12 deletions(-)
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 725b1e47e0cea..87b2112fe4628 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -34,6 +34,7 @@
import sys
import tarfile
import tempfile
+import textwrap
import time
import urllib2
import warnings
@@ -681,21 +682,32 @@ def setup_spark_cluster(master, opts):
print "Ganglia started at http://%s:5080/ganglia" % master
-def is_ssh_available(host, opts):
+def is_ssh_available(host, opts, print_ssh_output=True):
"""
Check if SSH is available on a host.
"""
- try:
- with open(os.devnull, 'w') as devnull:
- ret = subprocess.check_call(
- ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3',
- '%s@%s' % (opts.user, host), stringify_command('true')],
- stdout=devnull,
- stderr=devnull
- )
- return ret == 0
- except subprocess.CalledProcessError as e:
- return False
+ s = subprocess.Popen(
+ ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3',
+ '%s@%s' % (opts.user, host), stringify_command('true')],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order
+ )
+ cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout
+
+ if s.returncode != 0 and print_ssh_output:
+ # extra leading newline is for spacing in wait_for_cluster_state()
+ print textwrap.dedent("""\n
+ Warning: SSH connection error. (This could be temporary.)
+ Host: {h}
+ SSH return code: {r}
+ SSH output: {o}
+ """).format(
+ h=host,
+ r=s.returncode,
+ o=cmd_output.strip()
+ )
+
+ return s.returncode == 0
def is_cluster_ssh_available(cluster_instances, opts):
From 0793ee1b4dea1f4b0df749e8ad7c1ab70b512faf Mon Sep 17 00:00:00 2001
From: Sandy Ryza
Date: Mon, 9 Feb 2015 10:12:12 +0000
Subject: [PATCH 010/817] SPARK-2149. [MLLIB] Univariate kernel density
estimation
Author: Sandy Ryza
Closes #1093 from sryza/sandy-spark-2149 and squashes the following commits:
5f06b33 [Sandy Ryza] More review comments
0f73060 [Sandy Ryza] Respond to Sean's review comments
0dfa005 [Sandy Ryza] SPARK-2149. Univariate kernel density estimation
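A minimal usage sketch of the new public entry point (assuming a SparkContext `sc`; sample values are made up):
```
import org.apache.spark.mllib.stat.Statistics

// Samples drawn from some empirical distribution.
val samples = sc.parallelize(Seq(5.0, 6.0, 7.0, 10.0))
// Estimate the density at two evaluation points with a Gaussian kernel of stddev 3.0.
val densities = Statistics.kernelDensity(samples, 3.0, Seq(5.0, 6.0))
```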
---
.../spark/mllib/stat/KernelDensity.scala | 71 +++++++++++++++++++
.../apache/spark/mllib/stat/Statistics.scala | 14 ++++
.../spark/mllib/stat/KernelDensitySuite.scala | 47 ++++++++++++
3 files changed, 132 insertions(+)
create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
new file mode 100644
index 0000000000000..0deef11b4511a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.apache.spark.rdd.RDD
+
+private[stat] object KernelDensity {
+ /**
+ * Given a set of samples from a distribution, estimates its density at the set of given points.
+ * Uses a Gaussian kernel with the given standard deviation.
+ */
+ def estimate(samples: RDD[Double], standardDeviation: Double,
+ evaluationPoints: Array[Double]): Array[Double] = {
+ if (standardDeviation <= 0.0) {
+ throw new IllegalArgumentException("Standard deviation must be positive")
+ }
+
+ // This gets used in each Gaussian PDF computation, so compute it up front
+ val logStandardDeviationPlusHalfLog2Pi =
+ Math.log(standardDeviation) + 0.5 * Math.log(2 * Math.PI)
+
+ val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))(
+ (x, y) => {
+ var i = 0
+ while (i < evaluationPoints.length) {
+ x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi,
+ evaluationPoints(i))
+ i += 1
+ }
+ (x._1, i)
+ },
+ (x, y) => {
+ var i = 0
+ while (i < evaluationPoints.length) {
+ x._1(i) += y._1(i)
+ i += 1
+ }
+ (x._1, x._2 + y._2)
+ })
+
+ var i = 0
+ while (i < points.length) {
+ points(i) /= count
+ i += 1
+ }
+ points
+ }
+
+ private def normPdf(mean: Double, standardDeviation: Double,
+ logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = {
+ val x0 = x - mean
+ val x1 = x0 / standardDeviation
+ val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi
+ Math.exp(logDensity)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index b3fad0c52d655..32561620ac914 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -149,4 +149,18 @@ object Statistics {
def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
ChiSqTest.chiSquaredFeatures(data)
}
+
+ /**
+ * Given an empirical distribution defined by the input RDD of samples, estimate its density at
+ * each of the given evaluation points using a Gaussian kernel.
+ *
+ * @param samples The samples RDD used to define the empirical distribution.
+ * @param standardDeviation The standard deviation of the kernel Gaussians.
+ * @param evaluationPoints The points at which to estimate densities.
+ * @return An array the same size as evaluationPoints with the density at each point.
+ */
+ def kernelDensity(samples: RDD[Double], standardDeviation: Double,
+ evaluationPoints: Iterable[Double]): Array[Double] = {
+ KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
new file mode 100644
index 0000000000000..f6a1e19f50296
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.scalatest.FunSuite
+
+import org.apache.commons.math3.distribution.NormalDistribution
+
+import org.apache.spark.mllib.util.LocalClusterSparkContext
+
+class KernelDensitySuite extends FunSuite with LocalClusterSparkContext {
+ test("kernel density single sample") {
+ val rdd = sc.parallelize(Array(5.0))
+ val evaluationPoints = Array(5.0, 6.0)
+ val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
+ val normal = new NormalDistribution(5.0, 3.0)
+ val acceptableErr = 1e-6
+ assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr)
+ assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr)
+ }
+
+ test("kernel density multiple samples") {
+ val rdd = sc.parallelize(Array(5.0, 10.0))
+ val evaluationPoints = Array(5.0, 6.0)
+ val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
+ val normal1 = new NormalDistribution(5.0, 3.0)
+ val normal2 = new NormalDistribution(10.0, 3.0)
+ val acceptableErr = 1e-6
+ assert(math.abs(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr)
+ assert(math.abs(densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr)
+ }
+}
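A minimal usage sketch of the `Statistics.kernelDensity` API added above, run against hypothetical sample data and assuming an in-scope `SparkContext` named `sc`:
```scala
import org.apache.spark.mllib.stat.Statistics

// Estimate the density of a small sample at three evaluation points,
// using a Gaussian kernel with standard deviation (bandwidth) 3.0.
val samples = sc.parallelize(Seq(4.0, 5.0, 5.5, 10.0))
val densities: Array[Double] =
  Statistics.kernelDensity(samples, 3.0, Seq(5.0, 6.0, 10.0))
densities.foreach(println)
```
Each returned value is the average of the Gaussian PDFs centered at the samples, evaluated at the corresponding evaluation point.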
From de7806048ac49a8bfdf44d8f87bc11cea1dfb242 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Mon, 9 Feb 2015 10:33:57 -0800
Subject: [PATCH 011/817] SPARK-4267 [YARN] Failing to launch jobs on Spark on
YARN with Hadoop 2.5.0 or later
Before passing them to YARN, escape the arguments in "extraJavaOptions", in order to correctly handle cases like -Dfoo="one two three". Also standardize how these args are handled and ensure that individual args are treated as stand-alone args, not as one string.
vanzin andrewor14
Author: Sean Owen
Closes #4452 from srowen/SPARK-4267.2 and squashes the following commits:
c8297d2 [Sean Owen] Before passing to YARN, escape arguments in "extraJavaOptions" args, in order to correctly handle cases like -Dfoo="one two three". Also standardize how these args are handled and ensure that individual args are treated as stand-alone args, not one string.
---
.../org/apache/spark/deploy/yarn/Client.scala | 9 +++++----
.../spark/deploy/yarn/ExecutorRunnable.scala | 17 +++++++++--------
.../spark/deploy/yarn/YarnClusterSuite.scala | 6 ++++--
3 files changed, 18 insertions(+), 14 deletions(-)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index e7005094b5f3c..8afc1ccdad732 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -435,10 +435,11 @@ private[spark] class Client(
// Include driver-specific java options if we are launching a driver
if (isClusterMode) {
- sparkConf.getOption("spark.driver.extraJavaOptions")
+ val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions")
.orElse(sys.env.get("SPARK_JAVA_OPTS"))
- .map(Utils.splitCommandString).getOrElse(Seq.empty)
- .foreach(opts => javaOpts += opts)
+ driverOpts.foreach { opts =>
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ }
val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"),
sys.props.get("spark.driver.libraryPath")).flatten
if (libraryPaths.nonEmpty) {
@@ -460,7 +461,7 @@ private[spark] class Client(
val msg = s"$amOptsKey is not allowed to alter memory settings (was '$opts')."
throw new SparkException(msg)
}
- javaOpts ++= Utils.splitCommandString(opts)
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 408cf09b9bdfa..7cd8c5f0f9204 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -128,14 +128,15 @@ class ExecutorRunnable(
// Set the JVM memory
val executorMemoryString = executorMemory + "m"
- javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " "
+ javaOpts += "-Xms" + executorMemoryString
+ javaOpts += "-Xmx" + executorMemoryString
// Set extra Java options for the executor, if defined
sys.props.get("spark.executor.extraJavaOptions").foreach { opts =>
- javaOpts += opts
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
sys.env.get("SPARK_JAVA_OPTS").foreach { opts =>
- javaOpts += opts
+ javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
sys.props.get("spark.executor.extraLibraryPath").foreach { p =>
prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p)))
@@ -173,11 +174,11 @@ class ExecutorRunnable(
// The options are based on
// http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use
// %20the%20Concurrent%20Low%20Pause%20Collector|outline
- javaOpts += " -XX:+UseConcMarkSweepGC "
- javaOpts += " -XX:+CMSIncrementalMode "
- javaOpts += " -XX:+CMSIncrementalPacing "
- javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 "
- javaOpts += " -XX:CMSIncrementalDutyCycle=10 "
+ javaOpts += "-XX:+UseConcMarkSweepGC"
+ javaOpts += "-XX:+CMSIncrementalMode"
+ javaOpts += "-XX:+CMSIncrementalPacing"
+ javaOpts += "-XX:CMSIncrementalDutyCycleMin=0"
+ javaOpts += "-XX:CMSIncrementalDutyCycle=10"
}
*/
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index eda40efc4c77f..e39de82740b1d 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -75,6 +75,8 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
private var oldConf: Map[String, String] = _
override def beforeAll() {
+ super.beforeAll()
+
tempDir = Utils.createTempDir()
val logConfDir = new File(tempDir, "log4j")
@@ -129,8 +131,8 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
sys.props += ("spark.executor.instances" -> "1")
sys.props += ("spark.driver.extraClassPath" -> childClasspath)
sys.props += ("spark.executor.extraClassPath" -> childClasspath)
-
- super.beforeAll()
+ sys.props += ("spark.executor.extraJavaOptions" -> "-Dfoo=\"one two three\"")
+ sys.props += ("spark.driver.extraJavaOptions" -> "-Dfoo=\"one two three\"")
}
override def afterAll() {
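The essence of the fix is that an options string such as `-Dfoo="one two three"` must first be split into stand-alone arguments and each argument then shell-escaped before it reaches the YARN launch script. The sketch below only approximates Spark's private `Utils.splitCommandString` and `YarnSparkHadoopUtil.escapeForShell` helpers; the regex and quoting rules are simplified assumptions, not the real implementation:
```scala
// Simplified stand-ins for Spark's splitCommandString / escapeForShell.
def splitArgs(s: String): Seq[String] =
  """("[^"]*"|\S)+""".r.findAllIn(s).toSeq

def escapeForShell(arg: String): String =
  "'" + arg.replace("'", "'\\''") + "'"   // POSIX single-quote escaping

val opts = """-Dfoo="one two three" -verbose:gc"""
println(splitArgs(opts).map(escapeForShell).mkString(" "))
// prints two separately quoted arguments rather than one opaque string
```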
From afb131637d96e1e5e07eb8abf24e32e7f3b2304d Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 9 Feb 2015 11:42:52 -0800
Subject: [PATCH 012/817] [SPARK-5678] Convert DataFrame to pandas.DataFrame
and Series
```
pyspark.sql.DataFrame.to_pandas = to_pandas(self) unbound pyspark.sql.DataFrame method
Collect all the rows and return a `pandas.DataFrame`.
>>> df.to_pandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
pyspark.sql.Column.to_pandas = to_pandas(self) unbound pyspark.sql.Column method
Return a pandas.Series from the column
>>> df.age.to_pandas() # doctest: +SKIP
0 2
1 5
dtype: int64
```
Not tested by Jenkins (the doctests depend on pandas).
Author: Davies Liu
Closes #4476 from davies/to_pandas and squashes the following commits:
6276fb6 [Davies Liu] Convert DataFrame to pandas.DataFrame and Series
---
python/pyspark/sql.py | 25 +++++++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index e55f285a778c4..6a6dfbc5851b8 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2284,6 +2284,18 @@ def addColumn(self, colName, col):
"""
return self.select('*', col.alias(colName))
+ def to_pandas(self):
+ """
+ Collect all the rows and return a `pandas.DataFrame`.
+
+ >>> df.to_pandas() # doctest: +SKIP
+ age name
+ 0 2 Alice
+ 1 5 Bob
+ """
+ import pandas as pd
+ return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
# Having SchemaRDD for backward compatibility (for docs)
class SchemaRDD(DataFrame):
@@ -2551,6 +2563,19 @@ def cast(self, dataType):
jc = self._jc.cast(jdt)
return Column(jc, self.sql_ctx)
+ def to_pandas(self):
+ """
+ Return a pandas.Series from the column
+
+ >>> df.age.to_pandas() # doctest: +SKIP
+ 0 2
+ 1 5
+ dtype: int64
+ """
+ import pandas as pd
+ data = [c for c, in self.collect()]
+ return pd.Series(data)
+
def _aggregate_func(name, doc=""):
""" Create a function for aggregator by name"""
From dae216147f2247fd722fb0909da74fe71cf2fa8b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 9 Feb 2015 11:45:12 -0800
Subject: [PATCH 013/817] [SPARK-5664][BUILD] Restore stty settings when
exiting from SBT's spark-shell
For launching spark-shell from SBT.
Author: Liang-Chi Hsieh
Closes #4451 from viirya/restore_stty and squashes the following commits:
fdfc480 [Liang-Chi Hsieh] Restore stty settings when exit (for launching spark-shell from SBT).
---
build/sbt | 28 ++++++++++++++++++++++++++++
build/sbt-launch-lib.bash | 2 +-
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/build/sbt b/build/sbt
index 28ebb64f7197c..cc3203d79bccd 100755
--- a/build/sbt
+++ b/build/sbt
@@ -125,4 +125,32 @@ loadConfigFile() {
[[ -f "$etc_sbt_opts_file" ]] && set -- $(loadConfigFile "$etc_sbt_opts_file") "$@"
[[ -f "$sbt_opts_file" ]] && set -- $(loadConfigFile "$sbt_opts_file") "$@"
+exit_status=127
+saved_stty=""
+
+restoreSttySettings() {
+ stty $saved_stty
+ saved_stty=""
+}
+
+onExit() {
+ if [[ "$saved_stty" != "" ]]; then
+ restoreSttySettings
+ fi
+ exit $exit_status
+}
+
+saveSttySettings() {
+ saved_stty=$(stty -g 2>/dev/null)
+ if [[ ! $? ]]; then
+ saved_stty=""
+ fi
+}
+
+saveSttySettings
+trap onExit INT
+
run "$@"
+
+exit_status=$?
+onExit
diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
index 5e0c640fa5919..504be48b358fa 100755
--- a/build/sbt-launch-lib.bash
+++ b/build/sbt-launch-lib.bash
@@ -81,7 +81,7 @@ execRunner () {
echo ""
}
- exec "$@"
+ "$@"
}
addJava () {
From 6fe70d8432314f0b7290a66f114306f61e0a87cc Mon Sep 17 00:00:00 2001
From: mcheah
Date: Mon, 9 Feb 2015 13:20:14 -0800
Subject: [PATCH 014/817] [SPARK-5691] Fixing wrong data structure lookup for
dupe app registratio...
In Master's registerApplication method, it checks whether the application has already
been registered by examining the addressToWorker hash map. In reality, it should refer
to the addressToApp data structure, as this is what actually tracks which apps have
been registered.
Author: mcheah
Closes #4477 from mccheah/spark-5691 and squashes the following commits:
efdc573 [mcheah] [SPARK-5691] Fixing wrong data structure lookup for dupe app registration
---
core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index b8b1a25abff2e..53e453990f8c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -671,7 +671,7 @@ private[spark] class Master(
def registerApplication(app: ApplicationInfo): Unit = {
val appAddress = app.driver.path.address
- if (addressToWorker.contains(appAddress)) {
+ if (addressToApp.contains(appAddress)) {
logInfo("Attempted to re-register application at same address: " + appAddress)
return
}
From 0765af9b21e9204c410c7a849c7201bc3eda8cc3 Mon Sep 17 00:00:00 2001
From: Hari Shreedharan
Date: Mon, 9 Feb 2015 14:17:14 -0800
Subject: [PATCH 015/817] [SPARK-4905][STREAMING] FlumeStreamSuite fix.
Using the String constructor instead of a CharsetDecoder to see if it fixes the issue of empty strings in the Flume test output.
Author: Hari Shreedharan
Closes #4371 from harishreedharan/Flume-stream-attempted-fix and squashes the following commits:
550d363 [Hari Shreedharan] Fix imports.
8695950 [Hari Shreedharan] Use Charsets.UTF_8 instead of "UTF-8" in String constructors.
af3ba14 [Hari Shreedharan] [SPARK-4905][STREAMING] FlumeStreamSuite fix.
---
.../apache/spark/streaming/flume/FlumeStreamSuite.scala | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index f333e3891b5f0..322de7bf2fed8 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.streaming.flume
import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer
-import java.nio.charset.Charset
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps
+import com.google.common.base.Charsets
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.flume.source.avro
@@ -108,7 +108,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
val inputEvents = input.map { item =>
val event = new AvroFlumeEvent
- event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8")))
+ event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8)))
event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
event
}
@@ -138,14 +138,13 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
status should be (avro.Status.OK)
}
- val decoder = Charset.forName("UTF-8").newDecoder()
eventually(timeout(10 seconds), interval(100 milliseconds)) {
val outputEvents = outputBuffer.flatten.map { _.event }
outputEvents.foreach {
event =>
event.getHeaders.get("test") should be("header")
}
- val output = outputEvents.map(event => decoder.decode(event.getBody()).toString)
+ val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8))
output should be (input)
}
}
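A small standalone sketch of the idea behind the change: decoding each event body with the `String` constructor and Guava's `Charsets.UTF_8` is stateless, whereas a shared `CharsetDecoder` keeps internal state between `decode()` calls and is not thread-safe. The byte buffer here is hypothetical test data:
```scala
import java.nio.ByteBuffer
import com.google.common.base.Charsets

val body = ByteBuffer.wrap("test-event".getBytes(Charsets.UTF_8))
// Stateless decoding of an array-backed buffer, as in the updated suite.
val text = new String(body.array(), Charsets.UTF_8)
println(text) // test-event
```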
From f48199eb354d6ec8675c2c1f96c3005064058d66 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 9 Feb 2015 14:51:46 -0800
Subject: [PATCH 016/817] [SPARK-5675][SQL] XyzType companion object should
subclass XyzType
Otherwise, the following will always return false in Java.
```java
dataType instanceof StringType
```
Author: Reynold Xin
Closes #4463 from rxin/type-companion-object and squashes the following commits:
04d5d8d [Reynold Xin] Comment.
976e11e [Reynold Xin] [SPARK-5675][SQL]StringType case object should be subclass of StringType class
---
.../apache/spark/sql/types/dataTypes.scala | 85 ++++++++++++++++---
1 file changed, 73 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index 91efe320546a7..2abb1caee9cd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -240,10 +240,16 @@ abstract class DataType {
* @group dataType
*/
@DeveloperApi
-case object NullType extends DataType {
+class NullType private() extends DataType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "NullType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
override def defaultSize: Int = 1
}
+case object NullType extends NullType
+
+
protected[sql] object NativeType {
val all = Seq(
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
@@ -292,7 +298,10 @@ protected[sql] abstract class NativeType extends DataType {
* @group dataType
*/
@DeveloperApi
-case object StringType extends NativeType with PrimitiveType {
+class StringType private() extends NativeType with PrimitiveType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "StringType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
@@ -303,6 +312,8 @@ case object StringType extends NativeType with PrimitiveType {
override def defaultSize: Int = 4096
}
+case object StringType extends StringType
+
/**
* :: DeveloperApi ::
@@ -313,7 +324,10 @@ case object StringType extends NativeType with PrimitiveType {
* @group dataType
*/
@DeveloperApi
-case object BinaryType extends NativeType with PrimitiveType {
+class BinaryType private() extends NativeType with PrimitiveType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Array[Byte]
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = new Ordering[JvmType] {
@@ -332,6 +346,8 @@ case object BinaryType extends NativeType with PrimitiveType {
override def defaultSize: Int = 4096
}
+case object BinaryType extends BinaryType
+
/**
* :: DeveloperApi ::
@@ -341,7 +357,10 @@ case object BinaryType extends NativeType with PrimitiveType {
*@group dataType
*/
@DeveloperApi
-case object BooleanType extends NativeType with PrimitiveType {
+class BooleanType private() extends NativeType with PrimitiveType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
@@ -352,6 +371,8 @@ case object BooleanType extends NativeType with PrimitiveType {
override def defaultSize: Int = 1
}
+case object BooleanType extends BooleanType
+
/**
* :: DeveloperApi ::
@@ -362,7 +383,10 @@ case object BooleanType extends NativeType with PrimitiveType {
* @group dataType
*/
@DeveloperApi
-case object TimestampType extends NativeType {
+class TimestampType private() extends NativeType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Timestamp
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
@@ -377,6 +401,8 @@ case object TimestampType extends NativeType {
override def defaultSize: Int = 12
}
+case object TimestampType extends TimestampType
+
/**
* :: DeveloperApi ::
@@ -387,7 +413,10 @@ case object TimestampType extends NativeType {
* @group dataType
*/
@DeveloperApi
-case object DateType extends NativeType {
+class DateType private() extends NativeType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "DateType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Int
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
@@ -400,6 +429,8 @@ case object DateType extends NativeType {
override def defaultSize: Int = 4
}
+case object DateType extends DateType
+
abstract class NumericType extends NativeType with PrimitiveType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
@@ -438,7 +469,10 @@ protected[sql] sealed abstract class IntegralType extends NumericType {
* @group dataType
*/
@DeveloperApi
-case object LongType extends IntegralType {
+class LongType private() extends IntegralType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "LongType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Long
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val numeric = implicitly[Numeric[Long]]
@@ -453,6 +487,8 @@ case object LongType extends IntegralType {
override def simpleString = "bigint"
}
+case object LongType extends LongType
+
/**
* :: DeveloperApi ::
@@ -462,7 +498,10 @@ case object LongType extends IntegralType {
* @group dataType
*/
@DeveloperApi
-case object IntegerType extends IntegralType {
+class IntegerType private() extends IntegralType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Int
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val numeric = implicitly[Numeric[Int]]
@@ -477,6 +516,8 @@ case object IntegerType extends IntegralType {
override def simpleString = "int"
}
+case object IntegerType extends IntegerType
+
/**
* :: DeveloperApi ::
@@ -486,7 +527,10 @@ case object IntegerType extends IntegralType {
* @group dataType
*/
@DeveloperApi
-case object ShortType extends IntegralType {
+class ShortType private() extends IntegralType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "ShortType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Short
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val numeric = implicitly[Numeric[Short]]
@@ -501,6 +545,8 @@ case object ShortType extends IntegralType {
override def simpleString = "smallint"
}
+case object ShortType extends ShortType
+
/**
* :: DeveloperApi ::
@@ -510,7 +556,10 @@ case object ShortType extends IntegralType {
* @group dataType
*/
@DeveloperApi
-case object ByteType extends IntegralType {
+class ByteType private() extends IntegralType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "ByteType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Byte
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val numeric = implicitly[Numeric[Byte]]
@@ -525,6 +574,8 @@ case object ByteType extends IntegralType {
override def simpleString = "tinyint"
}
+case object ByteType extends ByteType
+
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
protected[sql] object FractionalType {
@@ -630,7 +681,10 @@ object DecimalType {
* @group dataType
*/
@DeveloperApi
-case object DoubleType extends FractionalType {
+class DoubleType private() extends FractionalType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Double
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val numeric = implicitly[Numeric[Double]]
@@ -644,6 +698,8 @@ case object DoubleType extends FractionalType {
override def defaultSize: Int = 8
}
+case object DoubleType extends DoubleType
+
/**
* :: DeveloperApi ::
@@ -653,7 +709,10 @@ case object DoubleType extends FractionalType {
* @group dataType
*/
@DeveloperApi
-case object FloatType extends FractionalType {
+class FloatType private() extends FractionalType {
+ // The companion object and this class is separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "FloatType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
private[sql] type JvmType = Float
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val numeric = implicitly[Numeric[Float]]
@@ -667,6 +726,8 @@ case object FloatType extends FractionalType {
override def defaultSize: Int = 4
}
+case object FloatType extends FloatType
+
object ArrayType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
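A self-contained sketch of the pattern applied to each type above: the case object extends the class it accompanies (which is why the class needs a private constructor), so a Java-side `dataType instanceof StringType` check, or `isInstanceOf[StringType]` in Scala, now succeeds for the singleton:
```scala
abstract class DataType
class StringType private() extends DataType
case object StringType extends StringType   // the companion can still call the private constructor

val dt: DataType = StringType
println(dt.isInstanceOf[StringType])        // true: the singleton is an instance of the class
```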
From b884daa58084d4f42e2318894067565b94e07f9d Mon Sep 17 00:00:00 2001
From: Florian Verhein
Date: Mon, 9 Feb 2015 23:47:07 +0000
Subject: [PATCH 017/817] [SPARK-5611] [EC2] Allow spark-ec2 repo and branch to
be set on CLI of spark_ec2.py
and by extension, the ami-list
Useful for using alternate spark-ec2 repos or branches.
Author: Florian Verhein
Closes #4385 from florianverhein/master and squashes the following commits:
7e2b4be [Florian Verhein] [SPARK-5611] [EC2] typo
8b653dc [Florian Verhein] [SPARK-5611] [EC2] Enforce only supporting spark-ec2 forks from github, log improvement
bc4b0ed [Florian Verhein] [SPARK-5611] allow spark-ec2 repos with different names
8b5c551 [Florian Verhein] improve option naming, fix logging, fix lint failing, add guard to enforce spark-ec2
7724308 [Florian Verhein] [SPARK-5611] [EC2] fixes
b42b68c [Florian Verhein] [SPARK-5611] [EC2] Allow spark-ec2 repo and branch to be set on CLI of spark_ec2.py
---
ec2/spark_ec2.py | 37 ++++++++++++++++++++++++++++++++-----
1 file changed, 32 insertions(+), 5 deletions(-)
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 87b2112fe4628..3e4c49c0e1db6 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -62,10 +62,10 @@
DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION
DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark"
-MESOS_SPARK_EC2_BRANCH = "branch-1.3"
-# A URL prefix from which to fetch AMI information
-AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH)
+# Default location to get the spark-ec2 scripts (and ami-list) from
+DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2"
+DEFAULT_SPARK_EC2_BRANCH = "branch-1.3"
def setup_boto():
@@ -147,6 +147,14 @@ def parse_args():
"--spark-git-repo",
default=DEFAULT_SPARK_GITHUB_REPO,
help="Github repo from which to checkout supplied commit hash (default: %default)")
+ parser.add_option(
+ "--spark-ec2-git-repo",
+ default=DEFAULT_SPARK_EC2_GITHUB_REPO,
+ help="Github repo from which to checkout spark-ec2 (default: %default)")
+ parser.add_option(
+ "--spark-ec2-git-branch",
+ default=DEFAULT_SPARK_EC2_BRANCH,
+ help="Github repo branch of spark-ec2 to use (default: %default)")
parser.add_option(
"--hadoop-major-version", default="1",
help="Major version of Hadoop (default: %default)")
@@ -333,7 +341,12 @@ def get_spark_ami(opts):
print >> stderr,\
"Don't recognize %s, assuming type is pvm" % opts.instance_type
- ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type)
+ # URL prefix from which to fetch AMI information
+ ami_prefix = "{r}/{b}/ami-list".format(
+ r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1),
+ b=opts.spark_ec2_git_branch)
+
+ ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type)
try:
ami = urllib2.urlopen(ami_path).read().strip()
print "Spark AMI: " + ami
@@ -650,12 +663,15 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
+ print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
+ r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)
ssh(
host=master,
opts=opts,
command="rm -rf spark-ec2"
+ " && "
- + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH)
+ + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo,
+ b=opts.spark_ec2_git_branch)
)
print "Deploying files to master..."
@@ -1038,6 +1054,17 @@ def real_main():
print >> stderr, "ebs-vol-num cannot be greater than 8"
sys.exit(1)
+ # Prevent breaking ami_prefix (/, .git and startswith checks)
+ # Prevent forks with non spark-ec2 names for now.
+ if opts.spark_ec2_git_repo.endswith("/") or \
+ opts.spark_ec2_git_repo.endswith(".git") or \
+ not opts.spark_ec2_git_repo.startswith("https://github.com") or \
+ not opts.spark_ec2_git_repo.endswith("spark-ec2"):
+ print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \
+ "trailing / or .git. " \
+ "Furthermore, we currently only support forks named spark-ec2."
+ sys.exit(1)
+
try:
conn = ec2.connect_to_region(opts.region)
except Exception as e:
From 68b25cf695e0fce9e465288d5a053e540a3fccb4 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Mon, 9 Feb 2015 16:02:56 -0800
Subject: [PATCH 018/817] [SQL] Add some missing DataFrame functions.
- as with a `Symbol`
- distinct
- sqlContext.emptyDataFrame
- move add/remove col out of RDDApi section
Author: Michael Armbrust
Closes #4437 from marmbrus/dfMissingFuncs and squashes the following commits:
2004023 [Michael Armbrust] Add missing functions
---
.../scala/org/apache/spark/sql/Column.scala | 9 +++++
.../org/apache/spark/sql/DataFrame.scala | 12 +++++--
.../org/apache/spark/sql/DataFrameImpl.scala | 34 +++++++++++--------
.../apache/spark/sql/IncomputableColumn.scala | 10 +++---
.../scala/org/apache/spark/sql/RDDApi.scala | 2 ++
.../org/apache/spark/sql/SQLContext.scala | 5 ++-
6 files changed, 51 insertions(+), 21 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 878b2b0556de7..1011bf0bb5ef4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -549,6 +549,15 @@ trait Column extends DataFrame {
*/
override def as(alias: String): Column = exprToColumn(Alias(expr, alias)())
+ /**
+ * Gives the column an alias.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".as('colB))
+ * }}}
+ */
+ override def as(alias: Symbol): Column = exprToColumn(Alias(expr, alias.name)())
+
/**
* Casts the column to a different data type.
* {{{
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 17ea3cde8e50e..6abfb7853cf1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -156,7 +156,7 @@ trait DataFrame extends RDDApi[Row] {
def join(right: DataFrame, joinExprs: Column): DataFrame
/**
- * Join with another [[DataFrame]], usin g the given join expression. The following performs
+ * Join with another [[DataFrame]], using the given join expression. The following performs
* a full outer join between `df1` and `df2`.
*
* {{{
@@ -233,7 +233,12 @@ trait DataFrame extends RDDApi[Row] {
/**
* Returns a new [[DataFrame]] with an alias set.
*/
- def as(name: String): DataFrame
+ def as(alias: String): DataFrame
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] with an alias set.
+ */
+ def as(alias: Symbol): DataFrame
/**
* Selects a set of expressions.
@@ -516,6 +521,9 @@ trait DataFrame extends RDDApi[Row] {
*/
override def repartition(numPartitions: Int): DataFrame
+ /** Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. */
+ override def distinct: DataFrame
+
override def persist(): this.type
override def persist(newLevel: StorageLevel): this.type
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index fa05a5dcac6bf..73393295ab0a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -196,7 +196,9 @@ private[sql] class DataFrameImpl protected[sql](
}.toSeq :_*)
}
- override def as(name: String): DataFrame = Subquery(name, logicalPlan)
+ override def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
+
+ override def as(alias: Symbol): DataFrame = Subquery(alias.name, logicalPlan)
override def select(cols: Column*): DataFrame = {
val exprs = cols.zipWithIndex.map {
@@ -215,7 +217,19 @@ private[sql] class DataFrameImpl protected[sql](
override def selectExpr(exprs: String*): DataFrame = {
select(exprs.map { expr =>
Column(new SqlParser().parseExpression(expr))
- } :_*)
+ }: _*)
+ }
+
+ override def addColumn(colName: String, col: Column): DataFrame = {
+ select(Column("*"), col.as(colName))
+ }
+
+ override def renameColumn(existingName: String, newName: String): DataFrame = {
+ val colNames = schema.map { field =>
+ val name = field.name
+ if (name == existingName) Column(name).as(newName) else Column(name)
+ }
+ select(colNames :_*)
}
override def filter(condition: Column): DataFrame = {
@@ -264,18 +278,8 @@ private[sql] class DataFrameImpl protected[sql](
}
/////////////////////////////////////////////////////////////////////////////
-
- override def addColumn(colName: String, col: Column): DataFrame = {
- select(Column("*"), col.as(colName))
- }
-
- override def renameColumn(existingName: String, newName: String): DataFrame = {
- val colNames = schema.map { field =>
- val name = field.name
- if (name == existingName) Column(name).as(newName) else Column(name)
- }
- select(colNames :_*)
- }
+ // RDD API
+ /////////////////////////////////////////////////////////////////////////////
override def head(n: Int): Array[Row] = limit(n).collect()
@@ -307,6 +311,8 @@ private[sql] class DataFrameImpl protected[sql](
sqlContext.applySchema(rdd.repartition(numPartitions), schema)
}
+ override def distinct: DataFrame = Distinct(logicalPlan)
+
override def persist(): this.type = {
sqlContext.cacheManager.cacheQuery(this)
this
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 782f6e28eebb0..0600dcc226b4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -86,6 +86,10 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def selectExpr(exprs: String*): DataFrame = err()
+ override def addColumn(colName: String, col: Column): DataFrame = err()
+
+ override def renameColumn(existingName: String, newName: String): DataFrame = err()
+
override def filter(condition: Column): DataFrame = err()
override def filter(conditionExpr: String): DataFrame = err()
@@ -110,10 +114,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
/////////////////////////////////////////////////////////////////////////////
- override def addColumn(colName: String, col: Column): DataFrame = err()
-
- override def renameColumn(existingName: String, newName: String): DataFrame = err()
-
override def head(n: Int): Array[Row] = err()
override def head(): Row = err()
@@ -140,6 +140,8 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def repartition(numPartitions: Int): DataFrame = err()
+ override def distinct: DataFrame = err()
+
override def persist(): this.type = err()
override def persist(newLevel: StorageLevel): this.type = err()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala
index 38e6382f171d5..df866fd1ad8ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala
@@ -60,4 +60,6 @@ private[sql] trait RDDApi[T] {
def first(): T
def repartition(numPartitions: Int): DataFrame
+
+ def distinct: DataFrame
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index bf3990671029e..97e3777f933e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
@@ -130,6 +130,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
val experimental: ExperimentalMethods = new ExperimentalMethods(this)
+ /** Returns a [[DataFrame]] with no rows or columns. */
+ lazy val emptyDataFrame = DataFrame(this, NoRelation)
+
/**
* A collection of methods for registering user-defined functions (UDF).
*
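A hedged usage sketch of the additions, assuming an existing `SQLContext` named `sqlContext` and a `DataFrame` `df` with a column "colA"; the `$` column syntax follows the doc example added to `Column.as`:
```scala
val aliasedCol = df.select($"colA".as('colB))   // alias a column via a Symbol
val aliasedDf  = df.as('t)                      // alias the whole DataFrame
val deduped    = df.distinct                    // keep only the unique rows
val empty      = sqlContext.emptyDataFrame      // a DataFrame with no rows or columns
```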
From 5f0b30e59cc6a3017168189d3aaf09402699dc3b Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Mon, 9 Feb 2015 16:20:42 -0800
Subject: [PATCH 019/817] [SQL] Code cleanup.
I added an unnecessary line of code in https://github.com/apache/spark/commit/13531dd97c08563e53dacdaeaf1102bdd13ef825.
My bad. Let's delete it.
Author: Yin Huai
Closes #4482 from yhuai/unnecessaryCode and squashes the following commits:
3645af0 [Yin Huai] Code cleanup.
---
.../org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 3 ---
1 file changed, 3 deletions(-)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index c23575fe96898..036efa84d7c85 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -351,9 +351,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|)
""".stripMargin)
- new Path("/Users/yhuai/Desktop/whatever")
-
-
val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable")
val filesystemPath = new Path(expectedPath)
val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration)
From b8080aa86d55e0467fd4328f10a2f0d6605e6cc6 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 9 Feb 2015 16:23:12 -0800
Subject: [PATCH 020/817] [SPARK-5696] [SQL] [HOTFIX] Asks HiveThriftServer2 to
re-initialize log4j using Hive configurations
In this way, log4j configurations overridden by jets3t-0.9.2.jar can in turn be overridden by Hive's default log4j configurations.
This might not be the best solution for this issue since it requires users to use `hive-log4j.properties` rather than `log4j.properties` to initialize `HiveThriftServer2` logging configurations, which can be confusing. The main purpose of this PR is to fix Jenkins PR build.
Author: Cheng Lian
Closes #4484 from liancheng/spark-5696 and squashes the following commits:
df83956 [Cheng Lian] Hot fix: asks HiveThriftServer2 to re-initialize log4j using Hive configurations
---
.../apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala | 3 +++
1 file changed, 3 insertions(+)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 6e07df18b0e15..525777aa454c4 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.thriftserver
import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.hive.common.LogUtils
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService}
@@ -54,6 +55,8 @@ object HiveThriftServer2 extends Logging {
System.exit(-1)
}
+ LogUtils.initHiveLog4j()
+
logInfo("Starting SparkContext")
SparkSQLEnv.init()
From 2a36292534a1e9f7a501e88f69bfc3a09fb62cb3 Mon Sep 17 00:00:00 2001
From: Lu Yan
Date: Mon, 9 Feb 2015 16:25:38 -0800
Subject: [PATCH 021/817] [SPARK-5614][SQL] Predicate pushdown through
Generate.
Currently, predicates cannot be pushed through "Generate" nodes by Catalyst's rules. Furthermore, partition pruning in HiveTableScan cannot be applied to queries that involve "Generate", which makes such queries very inefficient. This patch finds patterns like
```scala
Filter(predicate, Generate(generator, _, _, _, grandChild))
```
and splits the predicate into two parts, depending on whether each conjunct references a column generated by the Generate node. A new Filter is created for the conjuncts that can be pushed beneath the Generate node; if nothing is left for the original Filter, it is removed.
For example, physical plan for query
```sql
select len, bk
from s_server lateral view explode(len_arr) len_table as len
where len > 5 and day = '20150102';
```
where 'day' is a partition column in the metastore, looks like this in the current version of Spark SQL:
> Project [len, bk]
>
> Filter ((len > "5") && "(day = "20150102")")
>
> Generate explode(len_arr), true, false
>
> HiveTableScan [bk, len_arr, day], (MetastoreRelation default, s_server, None), None
But ideally the plan should look like this:
> Project [len, bk]
>
> Filter (len > "5")
>
> Generate explode(len_arr), true, false
>
> HiveTableScan [bk, len_arr, day], (MetastoreRelation default, s_server, None), Some(day = "20150102")
where the partition pruning predicate has been pushed down to the HiveTableScan node.
Author: Lu Yan
Closes #4394 from ianluyan/ppd and squashes the following commits:
a67dce9 [Lu Yan] Fix English grammar.
7cea911 [Lu Yan] Revised based on @marmbrus's opinions
ffc59fc [Lu Yan] [SPARK-5614][SQL] Predicate pushdown through Generate.
---
.../sql/catalyst/optimizer/Optimizer.scala | 25 ++++++++
.../optimizer/FilterPushdownSuite.scala | 63 ++++++++++++++++++-
2 files changed, 87 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3bc48c95c5653..fd58b9681ea24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -50,6 +50,7 @@ object DefaultOptimizer extends Optimizer {
CombineFilters,
PushPredicateThroughProject,
PushPredicateThroughJoin,
+ PushPredicateThroughGenerate,
ColumnPruning) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) :: Nil
@@ -455,6 +456,30 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
}
}
+/**
+ * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
+ * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
+ */
+object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case filter @ Filter(condition,
+ generate @ Generate(generator, join, outer, alias, grandChild)) =>
+ // Predicates that reference attributes produced by the `Generate` operator cannot
+ // be pushed below the operator.
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
+ conjunct => conjunct.references subsetOf grandChild.outputSet
+ }
+ if (pushDown.nonEmpty) {
+ val pushDownPredicate = pushDown.reduce(And)
+ val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
+ stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+ } else {
+ filter
+ }
+ }
+}
+
/**
* Pushes down [[Filter]] operators where the `condition` can be
* evaluated using only the attributes of the left or right side of a join. Other
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ebb123c1f909e..1158b5dfc6147 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types.IntegerType
class FilterPushdownSuite extends PlanTest {
@@ -34,7 +36,8 @@ class FilterPushdownSuite extends PlanTest {
Batch("Filter Pushdown", Once,
CombineFilters,
PushPredicateThroughProject,
- PushPredicateThroughJoin) :: Nil
+ PushPredicateThroughJoin,
+ PushPredicateThroughGenerate) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -411,4 +414,62 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer))
}
+
+ val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+
+ test("generate: predicate referenced no generated column") {
+ val originalQuery = {
+ testRelationWithArrayType
+ .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+ .where(('b >= 5) && ('a > 6))
+ }
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer = {
+ testRelationWithArrayType
+ .where(('b >= 5) && ('a > 6))
+ .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
+ }
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("generate: part of conjuncts referenced generated column") {
+ val generator = Explode(Seq("c"), 'c_arr)
+ val originalQuery = {
+ testRelationWithArrayType
+ .generate(generator, true, false, Some("arr"))
+ .where(('b >= 5) && ('c > 6))
+ }
+ val optimized = Optimize(originalQuery.analyze)
+ val referenceResult = {
+ testRelationWithArrayType
+ .where('b >= 5)
+ .generate(generator, true, false, Some("arr"))
+ .where('c > 6).analyze
+ }
+
+ // Since newly generated columns get different ids every time being analyzed
+ // e.g. comparePlans(originalQuery.analyze, originalQuery.analyze) fails.
+ // So we check operators manually here.
+ // Filter("c" > 6)
+ assertResult(classOf[Filter])(optimized.getClass)
+ assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size)
+ assertResult("c"){
+ optimized.asInstanceOf[Filter].condition.references.toSeq(0).name
+ }
+
+ // the rest part
+ comparePlans(optimized.children(0), referenceResult.children(0))
+ }
+
+ test("generate: all conjuncts referenced generated column") {
+ val originalQuery = {
+ testRelationWithArrayType
+ .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+ .where(('c > 6) || ('b > 5)).analyze
+ }
+ val optimized = Optimize(originalQuery)
+
+ comparePlans(optimized, originalQuery)
+ }
}
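A toy illustration of the split performed by `PushPredicateThroughGenerate`, with strings standing in for Catalyst expressions and a hypothetical `len` column produced by the generator:
```scala
val conjuncts     = Seq("len > 5", "day = '20150102'")
val generatedCols = Set("len")   // columns produced by explode(len_arr)

// A conjunct can be pushed below Generate only if it references none of the
// generated columns.
val (stayUp, pushDown) = conjuncts.partition(c => generatedCols.exists(c.contains))
println(pushDown)   // List(day = '20150102') -> becomes a Filter beneath Generate
println(stayUp)     // List(len > 5)          -> remains above Generate
```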
From 0ee53ebce9944722e76b2b28fae79d9956be9f17 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 9 Feb 2015 16:39:34 -0800
Subject: [PATCH 022/817] [SPARK-2096][SQL] support dot notation on array of
struct
~~The rule is simple: If you want `a.b` work, then `a` must be some level of nested array of struct(level 0 means just a StructType). And the result of `a.b` is same level of nested array of b-type.
An optimization is: the resolve chain looks like `Attribute -> GetItem -> GetField -> GetField ...`, so we could transmit the nested array information between `GetItem` and `GetField` to avoid repeated computation of `innerDataType` and `containsNullList` of that nested array.~~
marmbrus Could you take a look?
To evaluate `a.b`: if `a` is an array of structs, then `a.b` means getting field `b` from each element of `a` and returning the results as an array.
Author: Wenchen Fan
Closes #2405 from cloud-fan/nested-array-dot and squashes the following commits:
08a228a [Wenchen Fan] support dot notation on array of struct
---
.../sql/catalyst/analysis/Analyzer.scala | 30 +++++++++-------
.../catalyst/expressions/complexTypes.scala | 34 ++++++++++++++++---
.../sql/catalyst/optimizer/Optimizer.scala | 3 +-
.../ExpressionEvaluationSuite.scala | 2 +-
.../org/apache/spark/sql/json/JsonSuite.scala | 6 ++--
5 files changed, 53 insertions(+), 22 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0b59ed1739566..fb2ff014cef07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog,
* desired fields are found.
*/
protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
+ def findField(fields: Array[StructField]): Int = {
+ val checkField = (f: StructField) => resolver(f.name, fieldName)
+ val ordinal = fields.indexWhere(checkField)
+ if (ordinal == -1) {
+ sys.error(
+ s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+ } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+ sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+ } else {
+ ordinal
+ }
+ }
expr.dataType match {
case StructType(fields) =>
- val actualField = fields.filter(f => resolver(f.name, fieldName))
- if (actualField.length == 0) {
- sys.error(
- s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
- } else if (actualField.length == 1) {
- val field = actualField(0)
- GetField(expr, field, fields.indexOf(field))
- } else {
- sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
- }
+ val ordinal = findField(fields)
+ StructGetField(expr, fields(ordinal), ordinal)
+ case ArrayType(StructType(fields), containsNull) =>
+ val ordinal = findField(fields)
+ ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 66e2e5c4bafce..68051a2a2007e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
}
}
+
+trait GetField extends UnaryExpression {
+ self: Product =>
+
+ type EvaluatedType = Any
+ override def foldable = child.foldable
+ override def toString = s"$child.${field.name}"
+
+ def field: StructField
+}
+
/**
* Returns the value of fields in the Struct `child`.
*/
-case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
- type EvaluatedType = Any
+case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
def dataType = field.dataType
override def nullable = child.nullable || field.nullable
- override def foldable = child.foldable
override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}
+}
- override def toString = s"$child.${field.name}"
+/**
+ * Returns the array of value of fields in the Array of Struct `child`.
+ */
+case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
+ extends GetField {
+
+ def dataType = ArrayType(field.dataType, containsNull)
+ override def nullable = child.nullable
+
+ override def eval(input: Row): Any = {
+ val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+ if (baseValue == null) null else {
+ baseValue.map { row =>
+ if (row == null) null else row(ordinal)
+ }
+ }
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fd58b9681ea24..0da081ed1a6e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -209,7 +209,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
- case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+ case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+ case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 7cf6c80194f6c..dcfd8b28cb02a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite {
expr.dataType match {
case StructType(fields) =>
val field = fields.find(_.name == fieldName).get
- GetField(expr, field, fields.indexOf(field))
+ StructGetField(expr, field, fields.indexOf(field))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 926ba68828ee8..7870cf9b0a868 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -342,21 +342,19 @@ class JsonSuite extends QueryTest {
)
}
- ignore("Complex field and type inferring (Ignored)") {
+ test("GetField operation on complex data type") {
val jsonDF = jsonRDD(complexFieldAndType1)
jsonDF.registerTempTable("jsonTable")
- // Right now, "field1" and "field2" are treated as aliases. We should fix it.
checkAnswer(
sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
Row(true, "str1")
)
- // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2.
// Getting all values of a specific field from an array of structs.
checkAnswer(
sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
- Row(Seq(true, false), Seq("str1", null))
+ Row(Seq(true, false, null), Seq("str1", null, null))
)
}
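The same behaviour is reachable from PySpark once `arrayOfStruct.field1` resolves to the new `ArrayGetField` expression. A minimal sketch, assuming a live `sc` and `sqlCtx`; the JSON literal is a made-up stand-in modelled on the `complexFieldAndType1` fixture, not the fixture itself:
```
# Illustrative only: the JSON below is an assumption in the spirit of the
# JsonSuite test data.
rdd = sc.parallelize(['{"arrayOfStruct": [{"field1": true, "field2": "str1"}, '
                      '{"field1": false}, {"field3": null}]}'])
jsonDF = sqlCtx.jsonRDD(rdd)
jsonDF.registerTempTable("jsonTable")

# Indexing the array yields the fields of a single struct element.
sqlCtx.sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 "
           "from jsonTable").collect()

# Addressing the array itself now yields one value per element, with null for
# elements that do not define the field (the ArrayGetField case above).
sqlCtx.sql("select arrayOfStruct.field1, arrayOfStruct.field2 "
           "from jsonTable").collect()
```
With this data the second query is expected to line up with the `Row(Seq(true, false, null), Seq("str1", null, null))` expectation in the updated test above.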
From d08e7c2b498584609cb3c7922eaaa2a0d115603f Mon Sep 17 00:00:00 2001
From: DoingDone9 <799203320@qq.com>
Date: Mon, 9 Feb 2015 16:40:26 -0800
Subject: [PATCH 023/817] [SPARK-5648][SQL] support "alter ... unset
tblproperties("key")"
Make HiveContext support `alter ... unset tblproperties("key")` statements, for example:
alter view viewName unset tblproperties("k")
alter table tableName unset tblproperties("k")
Author: DoingDone9 <799203320@qq.com>
Closes #4424 from DoingDone9/unset and squashes the following commits:
6dd8bee [DoingDone9] support "alter ... unset tblproperties("key")"
---
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 ++
1 file changed, 2 insertions(+)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2a4b88092179f..f51af62d3340b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -103,6 +103,7 @@ private[hive] object HiveQl {
"TOK_CREATEINDEX",
"TOK_DROPDATABASE",
"TOK_DROPINDEX",
+ "TOK_DROPTABLE_PROPERTIES",
"TOK_MSCK",
"TOK_ALTERVIEW_ADDPARTS",
@@ -111,6 +112,7 @@ private[hive] object HiveQl {
"TOK_ALTERVIEW_PROPERTIES",
"TOK_ALTERVIEW_RENAME",
"TOK_CREATEVIEW",
+ "TOK_DROPVIEW_PROPERTIES",
"TOK_DROPVIEW",
"TOK_EXPORT",
From 3ec3ad295ddd1435da68251b7479ffb60aec7037 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 9 Feb 2015 16:52:05 -0800
Subject: [PATCH 024/817] [SPARK-5699] [SQL] [Tests] Runs hive-thriftserver
tests whenever SQL code is modified
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/4486)
Author: Cheng Lian
Closes #4486 from liancheng/spark-5699 and squashes the following commits:
538001d [Cheng Lian] Runs hive-thriftserver tests whenever SQL code is modified
---
dev/run-tests | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/dev/run-tests b/dev/run-tests
index 2257a566bb1bb..483958757a2dd 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -36,7 +36,7 @@ function handle_error () {
}
-# Build against the right verison of Hadoop.
+# Build against the right version of Hadoop.
{
if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then
if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then
@@ -77,7 +77,7 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl"
fi
}
-# Only run Hive tests if there are sql changes.
+# Only run Hive tests if there are SQL changes.
# Partial solution for SPARK-1455.
if [ -n "$AMPLAB_JENKINS" ]; then
git fetch origin master:master
@@ -183,7 +183,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
if [ -n "$_SQL_TESTS_ONLY" ]; then
# This must be an array of individual arguments. Otherwise, having one long string
# will be interpreted as a single test, which doesn't work.
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test" "mllib/test")
else
SBT_MAVEN_TEST_ARGS=("test")
fi
From d302c4800bf2f74eceb731169ddf1766136b7398 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Mon, 9 Feb 2015 17:33:29 -0800
Subject: [PATCH 025/817] [SPARK-5698] Do not let user request negative # of
executors
Otherwise we might crash the ApplicationMaster. Why? Please see https://issues.apache.org/jira/browse/SPARK-5698.
sryza I believe this is also relevant in your patch #4168.
Author: Andrew Or
Closes #4483 from andrewor14/da-negative and squashes the following commits:
53ed955 [Andrew Or] Throw IllegalArgumentException instead
0e89fd5 [Andrew Or] Check against negative requests
---
.../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 9d2fb4f3b4729..f9ca93432bf41 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -314,6 +314,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* Return whether the request is acknowledged.
*/
final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
+ if (numAdditionalExecutors < 0) {
+ throw new IllegalArgumentException(
+ "Attempted to request a negative number of additional executor(s) " +
+ s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!")
+ }
logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
logDebug(s"Number of pending executors is now $numPendingExecutors")
numPendingExecutors += numAdditionalExecutors
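For illustration only, a Python paraphrase of the fail-fast guard this hunk adds; `request_executors` is a hypothetical stand-in, not a PySpark API:
```
def request_executors(num_additional_executors):
    """Hypothetical sketch mirroring the guard in requestExecutors above."""
    if num_additional_executors < 0:
        # Reject the request up front instead of letting a negative delta
        # reach (and potentially crash) the ApplicationMaster.
        raise ValueError(
            "Attempted to request a negative number of additional executor(s) "
            "%d from the cluster manager. Please specify a positive number!"
            % num_additional_executors)
    # ... forward the (non-negative) request to the cluster manager ...
    return True
```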
From 08488c175f2e8532cb6aab84da2abd9ad57179cc Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 9 Feb 2015 20:49:22 -0800
Subject: [PATCH 026/817] [SPARK-5469] restructure pyspark.sql into multiple
files
All of the DataType classes are moved into pyspark.sql.types.
The changes can be tracked with `--find-copies-harder -M25`:
```
davies@localhost:~/work/spark/python$ git diff --find-copies-harder -M25 --numstat master..
2 5 python/docs/pyspark.ml.rst
0 3 python/docs/pyspark.mllib.rst
10 2 python/docs/pyspark.sql.rst
1 1 python/pyspark/mllib/linalg.py
21 14 python/pyspark/{mllib => sql}/__init__.py
14 2108 python/pyspark/{sql.py => sql/context.py}
10 1772 python/pyspark/{sql.py => sql/dataframe.py}
7 6 python/pyspark/{sql_tests.py => sql/tests.py}
8 1465 python/pyspark/{sql.py => sql/types.py}
4 2 python/run-tests
1 1 sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
```
Also run `git blame -C -C python/pyspark/sql/context.py` to track the history.
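The main user-visible effect is the import path of the type classes, which now live under `pyspark.sql.types` (the `python/pyspark/mllib/linalg.py` hunk below shows the same change in-tree); a minimal sketch:
```
# After this patch the DataType hierarchy lives in the new pyspark.sql.types
# module (previously everything was defined in the single pyspark/sql.py).
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

schema = StructType([
    StructField("field1", IntegerType(), False),
    StructField("field2", StringType(), True),
])
print(schema.json())
```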
Author: Davies Liu
Closes #4479 from davies/sql and squashes the following commits:
1b5f0a5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into sql
2b2b983 [Davies Liu] restructure pyspark.sql
---
python/docs/pyspark.ml.rst | 7 +-
python/docs/pyspark.mllib.rst | 3 -
python/docs/pyspark.sql.rst | 12 +-
python/pyspark/mllib/linalg.py | 2 +-
python/pyspark/sql.py | 2736 -----------------
python/pyspark/sql/__init__.py | 42 +
python/pyspark/sql/context.py | 642 ++++
python/pyspark/sql/dataframe.py | 974 ++++++
python/pyspark/{sql_tests.py => sql/tests.py} | 13 +-
python/pyspark/sql/types.py | 1279 ++++++++
python/run-tests | 6 +-
.../spark/sql/test/ExamplePointUDT.scala | 2 +-
12 files changed, 2962 insertions(+), 2756 deletions(-)
delete mode 100644 python/pyspark/sql.py
create mode 100644 python/pyspark/sql/__init__.py
create mode 100644 python/pyspark/sql/context.py
create mode 100644 python/pyspark/sql/dataframe.py
rename python/pyspark/{sql_tests.py => sql/tests.py} (96%)
create mode 100644 python/pyspark/sql/types.py
diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
index f10d1339a9a8f..4da6d4a74a299 100644
--- a/python/docs/pyspark.ml.rst
+++ b/python/docs/pyspark.ml.rst
@@ -1,11 +1,8 @@
pyspark.ml package
=====================
-Submodules
-----------
-
-pyspark.ml module
------------------
+Module Context
+--------------
.. automodule:: pyspark.ml
:members:
diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst
index 4548b8739ed91..21f66ca344a3c 100644
--- a/python/docs/pyspark.mllib.rst
+++ b/python/docs/pyspark.mllib.rst
@@ -1,9 +1,6 @@
pyspark.mllib package
=====================
-Submodules
-----------
-
pyspark.mllib.classification module
-----------------------------------
diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst
index 65b3650ae10ab..80c6f02a9df41 100644
--- a/python/docs/pyspark.sql.rst
+++ b/python/docs/pyspark.sql.rst
@@ -1,10 +1,18 @@
pyspark.sql module
==================
-Module contents
----------------
+Module Context
+--------------
.. automodule:: pyspark.sql
:members:
:undoc-members:
:show-inheritance:
+
+
+pyspark.sql.types module
+------------------------
+.. automodule:: pyspark.sql.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 7f21190ed8c25..597012b1c967c 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,7 +29,7 @@
import numpy as np
-from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
+from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
IntegerType, ByteType
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
deleted file mode 100644
index 6a6dfbc5851b8..0000000000000
--- a/python/pyspark/sql.py
+++ /dev/null
@@ -1,2736 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-public classes of Spark SQL:
-
- - L{SQLContext}
- Main entry point for SQL functionality.
- - L{DataFrame}
- A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
- addition to normal RDD operations, DataFrames also support SQL.
- - L{GroupedData}
- - L{Column}
- Column is a DataFrame with a single column.
- - L{Row}
- A Row of data returned by a Spark SQL query.
- - L{HiveContext}
- Main entry point for accessing data stored in Apache Hive..
-"""
-
-import sys
-import itertools
-import decimal
-import datetime
-import keyword
-import warnings
-import json
-import re
-import random
-import os
-from tempfile import NamedTemporaryFile
-from array import array
-from operator import itemgetter
-from itertools import imap
-
-from py4j.protocol import Py4JError
-from py4j.java_collections import ListConverter, MapConverter
-
-from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _prepare_for_python_RDD
-from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
- CloudPickleSerializer, UTF8Deserializer
-from pyspark.storagelevel import StorageLevel
-from pyspark.traceback_utils import SCCallSiteSync
-
-
-__all__ = [
- "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
- "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
- "ShortType", "ArrayType", "MapType", "StructField", "StructType",
- "SQLContext", "HiveContext", "DataFrame", "GroupedData", "Column", "Row", "Dsl",
- "SchemaRDD"]
-
-
-class DataType(object):
-
- """Spark SQL DataType"""
-
- def __repr__(self):
- return self.__class__.__name__
-
- def __hash__(self):
- return hash(str(self))
-
- def __eq__(self, other):
- return (isinstance(other, self.__class__) and
- self.__dict__ == other.__dict__)
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- @classmethod
- def typeName(cls):
- return cls.__name__[:-4].lower()
-
- def jsonValue(self):
- return self.typeName()
-
- def json(self):
- return json.dumps(self.jsonValue(),
- separators=(',', ':'),
- sort_keys=True)
-
-
-class PrimitiveTypeSingleton(type):
-
- """Metaclass for PrimitiveType"""
-
- _instances = {}
-
- def __call__(cls):
- if cls not in cls._instances:
- cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
- return cls._instances[cls]
-
-
-class PrimitiveType(DataType):
-
- """Spark SQL PrimitiveType"""
-
- __metaclass__ = PrimitiveTypeSingleton
-
- def __eq__(self, other):
- # because they should be the same object
- return self is other
-
-
-class NullType(PrimitiveType):
-
- """Spark SQL NullType
-
- The data type representing None, used for the types which has not
- been inferred.
- """
-
-
-class StringType(PrimitiveType):
-
- """Spark SQL StringType
-
- The data type representing string values.
- """
-
-
-class BinaryType(PrimitiveType):
-
- """Spark SQL BinaryType
-
- The data type representing bytearray values.
- """
-
-
-class BooleanType(PrimitiveType):
-
- """Spark SQL BooleanType
-
- The data type representing bool values.
- """
-
-
-class DateType(PrimitiveType):
-
- """Spark SQL DateType
-
- The data type representing datetime.date values.
- """
-
-
-class TimestampType(PrimitiveType):
-
- """Spark SQL TimestampType
-
- The data type representing datetime.datetime values.
- """
-
-
-class DecimalType(DataType):
-
- """Spark SQL DecimalType
-
- The data type representing decimal.Decimal values.
- """
-
- def __init__(self, precision=None, scale=None):
- self.precision = precision
- self.scale = scale
- self.hasPrecisionInfo = precision is not None
-
- def jsonValue(self):
- if self.hasPrecisionInfo:
- return "decimal(%d,%d)" % (self.precision, self.scale)
- else:
- return "decimal"
-
- def __repr__(self):
- if self.hasPrecisionInfo:
- return "DecimalType(%d,%d)" % (self.precision, self.scale)
- else:
- return "DecimalType()"
-
-
-class DoubleType(PrimitiveType):
-
- """Spark SQL DoubleType
-
- The data type representing float values.
- """
-
-
-class FloatType(PrimitiveType):
-
- """Spark SQL FloatType
-
- The data type representing single precision floating-point values.
- """
-
-
-class ByteType(PrimitiveType):
-
- """Spark SQL ByteType
-
- The data type representing int values with 1 singed byte.
- """
-
-
-class IntegerType(PrimitiveType):
-
- """Spark SQL IntegerType
-
- The data type representing int values.
- """
-
-
-class LongType(PrimitiveType):
-
- """Spark SQL LongType
-
- The data type representing long values. If the any value is
- beyond the range of [-9223372036854775808, 9223372036854775807],
- please use DecimalType.
- """
-
-
-class ShortType(PrimitiveType):
-
- """Spark SQL ShortType
-
- The data type representing int values with 2 signed bytes.
- """
-
-
-class ArrayType(DataType):
-
- """Spark SQL ArrayType
-
- The data type representing list values. An ArrayType object
- comprises two fields, elementType (a DataType) and containsNull (a bool).
- The field of elementType is used to specify the type of array elements.
- The field of containsNull is used to specify if the array has None values.
-
- """
-
- def __init__(self, elementType, containsNull=True):
- """Creates an ArrayType
-
- :param elementType: the data type of elements.
- :param containsNull: indicates whether the list contains None values.
-
- >>> ArrayType(StringType) == ArrayType(StringType, True)
- True
- >>> ArrayType(StringType, False) == ArrayType(StringType)
- False
- """
- self.elementType = elementType
- self.containsNull = containsNull
-
- def __repr__(self):
- return "ArrayType(%s,%s)" % (self.elementType,
- str(self.containsNull).lower())
-
- def jsonValue(self):
- return {"type": self.typeName(),
- "elementType": self.elementType.jsonValue(),
- "containsNull": self.containsNull}
-
- @classmethod
- def fromJson(cls, json):
- return ArrayType(_parse_datatype_json_value(json["elementType"]),
- json["containsNull"])
-
-
-class MapType(DataType):
-
- """Spark SQL MapType
-
- The data type representing dict values. A MapType object comprises
- three fields, keyType (a DataType), valueType (a DataType) and
- valueContainsNull (a bool).
-
- The field of keyType is used to specify the type of keys in the map.
- The field of valueType is used to specify the type of values in the map.
- The field of valueContainsNull is used to specify if values of this
- map has None values.
-
- For values of a MapType column, keys are not allowed to have None values.
-
- """
-
- def __init__(self, keyType, valueType, valueContainsNull=True):
- """Creates a MapType
- :param keyType: the data type of keys.
- :param valueType: the data type of values.
- :param valueContainsNull: indicates whether values contains
- null values.
-
- >>> (MapType(StringType, IntegerType)
- ... == MapType(StringType, IntegerType, True))
- True
- >>> (MapType(StringType, IntegerType, False)
- ... == MapType(StringType, FloatType))
- False
- """
- self.keyType = keyType
- self.valueType = valueType
- self.valueContainsNull = valueContainsNull
-
- def __repr__(self):
- return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
- str(self.valueContainsNull).lower())
-
- def jsonValue(self):
- return {"type": self.typeName(),
- "keyType": self.keyType.jsonValue(),
- "valueType": self.valueType.jsonValue(),
- "valueContainsNull": self.valueContainsNull}
-
- @classmethod
- def fromJson(cls, json):
- return MapType(_parse_datatype_json_value(json["keyType"]),
- _parse_datatype_json_value(json["valueType"]),
- json["valueContainsNull"])
-
-
-class StructField(DataType):
-
- """Spark SQL StructField
-
- Represents a field in a StructType.
- A StructField object comprises three fields, name (a string),
- dataType (a DataType) and nullable (a bool). The field of name
- is the name of a StructField. The field of dataType specifies
- the data type of a StructField.
-
- The field of nullable specifies if values of a StructField can
- contain None values.
-
- """
-
- def __init__(self, name, dataType, nullable=True, metadata=None):
- """Creates a StructField
- :param name: the name of this field.
- :param dataType: the data type of this field.
- :param nullable: indicates whether values of this field
- can be null.
- :param metadata: metadata of this field, which is a map from string
- to simple type that can be serialized to JSON
- automatically
-
- >>> (StructField("f1", StringType, True)
- ... == StructField("f1", StringType, True))
- True
- >>> (StructField("f1", StringType, True)
- ... == StructField("f2", StringType, True))
- False
- """
- self.name = name
- self.dataType = dataType
- self.nullable = nullable
- self.metadata = metadata or {}
-
- def __repr__(self):
- return "StructField(%s,%s,%s)" % (self.name, self.dataType,
- str(self.nullable).lower())
-
- def jsonValue(self):
- return {"name": self.name,
- "type": self.dataType.jsonValue(),
- "nullable": self.nullable,
- "metadata": self.metadata}
-
- @classmethod
- def fromJson(cls, json):
- return StructField(json["name"],
- _parse_datatype_json_value(json["type"]),
- json["nullable"],
- json["metadata"])
-
-
-class StructType(DataType):
-
- """Spark SQL StructType
-
- The data type representing rows.
- A StructType object comprises a list of L{StructField}.
-
- """
-
- def __init__(self, fields):
- """Creates a StructType
-
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True)])
- >>> struct1 == struct2
- True
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True),
- ... [StructField("f2", IntegerType, False)]])
- >>> struct1 == struct2
- False
- """
- self.fields = fields
-
- def __repr__(self):
- return ("StructType(List(%s))" %
- ",".join(str(field) for field in self.fields))
-
- def jsonValue(self):
- return {"type": self.typeName(),
- "fields": [f.jsonValue() for f in self.fields]}
-
- @classmethod
- def fromJson(cls, json):
- return StructType([StructField.fromJson(f) for f in json["fields"]])
-
-
-class UserDefinedType(DataType):
- """
- .. note:: WARN: Spark Internal Use Only
- SQL User-Defined Type (UDT).
- """
-
- @classmethod
- def typeName(cls):
- return cls.__name__.lower()
-
- @classmethod
- def sqlType(cls):
- """
- Underlying SQL storage type for this UDT.
- """
- raise NotImplementedError("UDT must implement sqlType().")
-
- @classmethod
- def module(cls):
- """
- The Python module of the UDT.
- """
- raise NotImplementedError("UDT must implement module().")
-
- @classmethod
- def scalaUDT(cls):
- """
- The class name of the paired Scala UDT.
- """
- raise NotImplementedError("UDT must have a paired Scala UDT.")
-
- def serialize(self, obj):
- """
- Converts the a user-type object into a SQL datum.
- """
- raise NotImplementedError("UDT must implement serialize().")
-
- def deserialize(self, datum):
- """
- Converts a SQL datum into a user-type object.
- """
- raise NotImplementedError("UDT must implement deserialize().")
-
- def json(self):
- return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
-
- def jsonValue(self):
- schema = {
- "type": "udt",
- "class": self.scalaUDT(),
- "pyClass": "%s.%s" % (self.module(), type(self).__name__),
- "sqlType": self.sqlType().jsonValue()
- }
- return schema
-
- @classmethod
- def fromJson(cls, json):
- pyUDT = json["pyClass"]
- split = pyUDT.rfind(".")
- pyModule = pyUDT[:split]
- pyClass = pyUDT[split+1:]
- m = __import__(pyModule, globals(), locals(), [pyClass], -1)
- UDT = getattr(m, pyClass)
- return UDT()
-
- def __eq__(self, other):
- return type(self) == type(other)
-
-
-_all_primitive_types = dict((v.typeName(), v)
- for v in globals().itervalues()
- if type(v) is PrimitiveTypeSingleton and
- v.__base__ == PrimitiveType)
-
-
-_all_complex_types = dict((v.typeName(), v)
- for v in [ArrayType, MapType, StructType])
-
-
-def _parse_datatype_json_string(json_string):
- """Parses the given data type JSON string.
- >>> def check_datatype(datatype):
- ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
- ... python_datatype = _parse_datatype_json_string(scala_datatype.json())
- ... return datatype == python_datatype
- >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
- True
- >>> # Simple ArrayType.
- >>> simple_arraytype = ArrayType(StringType(), True)
- >>> check_datatype(simple_arraytype)
- True
- >>> # Simple MapType.
- >>> simple_maptype = MapType(StringType(), LongType())
- >>> check_datatype(simple_maptype)
- True
- >>> # Simple StructType.
- >>> simple_structtype = StructType([
- ... StructField("a", DecimalType(), False),
- ... StructField("b", BooleanType(), True),
- ... StructField("c", LongType(), True),
- ... StructField("d", BinaryType(), False)])
- >>> check_datatype(simple_structtype)
- True
- >>> # Complex StructType.
- >>> complex_structtype = StructType([
- ... StructField("simpleArray", simple_arraytype, True),
- ... StructField("simpleMap", simple_maptype, True),
- ... StructField("simpleStruct", simple_structtype, True),
- ... StructField("boolean", BooleanType(), False),
- ... StructField("withMeta", DoubleType(), False, {"name": "age"})])
- >>> check_datatype(complex_structtype)
- True
- >>> # Complex ArrayType.
- >>> complex_arraytype = ArrayType(complex_structtype, True)
- >>> check_datatype(complex_arraytype)
- True
- >>> # Complex MapType.
- >>> complex_maptype = MapType(complex_structtype,
- ... complex_arraytype, False)
- >>> check_datatype(complex_maptype)
- True
- >>> check_datatype(ExamplePointUDT())
- True
- >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
- ... StructField("point", ExamplePointUDT(), False)])
- >>> check_datatype(structtype_with_udt)
- True
- """
- return _parse_datatype_json_value(json.loads(json_string))
-
-
-_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
-
-
-def _parse_datatype_json_value(json_value):
- if type(json_value) is unicode:
- if json_value in _all_primitive_types.keys():
- return _all_primitive_types[json_value]()
- elif json_value == u'decimal':
- return DecimalType()
- elif _FIXED_DECIMAL.match(json_value):
- m = _FIXED_DECIMAL.match(json_value)
- return DecimalType(int(m.group(1)), int(m.group(2)))
- else:
- raise ValueError("Could not parse datatype: %s" % json_value)
- else:
- tpe = json_value["type"]
- if tpe in _all_complex_types:
- return _all_complex_types[tpe].fromJson(json_value)
- elif tpe == 'udt':
- return UserDefinedType.fromJson(json_value)
- else:
- raise ValueError("not supported type: %s" % tpe)
-
-
-# Mapping Python types to Spark SQL DataType
-_type_mappings = {
- type(None): NullType,
- bool: BooleanType,
- int: IntegerType,
- long: LongType,
- float: DoubleType,
- str: StringType,
- unicode: StringType,
- bytearray: BinaryType,
- decimal.Decimal: DecimalType,
- datetime.date: DateType,
- datetime.datetime: TimestampType,
- datetime.time: TimestampType,
-}
-
-
-def _infer_type(obj):
- """Infer the DataType from obj
-
- >>> p = ExamplePoint(1.0, 2.0)
- >>> _infer_type(p)
- ExamplePointUDT
- """
- if obj is None:
- raise ValueError("Can not infer type for None")
-
- if hasattr(obj, '__UDT__'):
- return obj.__UDT__
-
- dataType = _type_mappings.get(type(obj))
- if dataType is not None:
- return dataType()
-
- if isinstance(obj, dict):
- for key, value in obj.iteritems():
- if key is not None and value is not None:
- return MapType(_infer_type(key), _infer_type(value), True)
- else:
- return MapType(NullType(), NullType(), True)
- elif isinstance(obj, (list, array)):
- for v in obj:
- if v is not None:
- return ArrayType(_infer_type(obj[0]), True)
- else:
- return ArrayType(NullType(), True)
- else:
- try:
- return _infer_schema(obj)
- except ValueError:
- raise ValueError("not supported type: %s" % type(obj))
-
-
-def _infer_schema(row):
- """Infer the schema from dict/namedtuple/object"""
- if isinstance(row, dict):
- items = sorted(row.items())
-
- elif isinstance(row, tuple):
- if hasattr(row, "_fields"): # namedtuple
- items = zip(row._fields, tuple(row))
- elif hasattr(row, "__FIELDS__"): # Row
- items = zip(row.__FIELDS__, tuple(row))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
- items = row
- else:
- raise ValueError("Can't infer schema from tuple")
-
- elif hasattr(row, "__dict__"): # object
- items = sorted(row.__dict__.items())
-
- else:
- raise ValueError("Can not infer schema for type: %s" % type(row))
-
- fields = [StructField(k, _infer_type(v), True) for k, v in items]
- return StructType(fields)
-
-
-def _need_python_to_sql_conversion(dataType):
- """
- Checks whether we need python to sql conversion for the given type.
- For now, only UDTs need this conversion.
-
- >>> _need_python_to_sql_conversion(DoubleType())
- False
- >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
- ... StructField("values", ArrayType(DoubleType(), False), False)])
- >>> _need_python_to_sql_conversion(schema0)
- False
- >>> _need_python_to_sql_conversion(ExamplePointUDT())
- True
- >>> schema1 = ArrayType(ExamplePointUDT(), False)
- >>> _need_python_to_sql_conversion(schema1)
- True
- >>> schema2 = StructType([StructField("label", DoubleType(), False),
- ... StructField("point", ExamplePointUDT(), False)])
- >>> _need_python_to_sql_conversion(schema2)
- True
- """
- if isinstance(dataType, StructType):
- return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
- elif isinstance(dataType, ArrayType):
- return _need_python_to_sql_conversion(dataType.elementType)
- elif isinstance(dataType, MapType):
- return _need_python_to_sql_conversion(dataType.keyType) or \
- _need_python_to_sql_conversion(dataType.valueType)
- elif isinstance(dataType, UserDefinedType):
- return True
- else:
- return False
-
-
-def _python_to_sql_converter(dataType):
- """
- Returns a converter that converts a Python object into a SQL datum for the given type.
-
- >>> conv = _python_to_sql_converter(DoubleType())
- >>> conv(1.0)
- 1.0
- >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
- >>> conv([1.0, 2.0])
- [1.0, 2.0]
- >>> conv = _python_to_sql_converter(ExamplePointUDT())
- >>> conv(ExamplePoint(1.0, 2.0))
- [1.0, 2.0]
- >>> schema = StructType([StructField("label", DoubleType(), False),
- ... StructField("point", ExamplePointUDT(), False)])
- >>> conv = _python_to_sql_converter(schema)
- >>> conv((1.0, ExamplePoint(1.0, 2.0)))
- (1.0, [1.0, 2.0])
- """
- if not _need_python_to_sql_conversion(dataType):
- return lambda x: x
-
- if isinstance(dataType, StructType):
- names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
- converters = map(_python_to_sql_converter, types)
-
- def converter(obj):
- if isinstance(obj, dict):
- return tuple(c(obj.get(n)) for n, c in zip(names, converters))
- elif isinstance(obj, tuple):
- if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
- return tuple(c(v) for c, v in zip(converters, obj))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
- d = dict(obj)
- return tuple(c(d.get(n)) for n, c in zip(names, converters))
- else:
- return tuple(c(v) for c, v in zip(converters, obj))
- else:
- raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
- return converter
- elif isinstance(dataType, ArrayType):
- element_converter = _python_to_sql_converter(dataType.elementType)
- return lambda a: [element_converter(v) for v in a]
- elif isinstance(dataType, MapType):
- key_converter = _python_to_sql_converter(dataType.keyType)
- value_converter = _python_to_sql_converter(dataType.valueType)
- return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
- elif isinstance(dataType, UserDefinedType):
- return lambda obj: dataType.serialize(obj)
- else:
- raise ValueError("Unexpected type %r" % dataType)
-
-
-def _has_nulltype(dt):
- """ Return whether there is NullType in `dt` or not """
- if isinstance(dt, StructType):
- return any(_has_nulltype(f.dataType) for f in dt.fields)
- elif isinstance(dt, ArrayType):
- return _has_nulltype((dt.elementType))
- elif isinstance(dt, MapType):
- return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
- else:
- return isinstance(dt, NullType)
-
-
-def _merge_type(a, b):
- if isinstance(a, NullType):
- return b
- elif isinstance(b, NullType):
- return a
- elif type(a) is not type(b):
- # TODO: type cast (such as int -> long)
- raise TypeError("Can not merge type %s and %s" % (a, b))
-
- # same type
- if isinstance(a, StructType):
- nfs = dict((f.name, f.dataType) for f in b.fields)
- fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
- for f in a.fields]
- names = set([f.name for f in fields])
- for n in nfs:
- if n not in names:
- fields.append(StructField(n, nfs[n]))
- return StructType(fields)
-
- elif isinstance(a, ArrayType):
- return ArrayType(_merge_type(a.elementType, b.elementType), True)
-
- elif isinstance(a, MapType):
- return MapType(_merge_type(a.keyType, b.keyType),
- _merge_type(a.valueType, b.valueType),
- True)
- else:
- return a
-
-
-def _create_converter(dataType):
- """Create an converter to drop the names of fields in obj """
- if isinstance(dataType, ArrayType):
- conv = _create_converter(dataType.elementType)
- return lambda row: map(conv, row)
-
- elif isinstance(dataType, MapType):
- kconv = _create_converter(dataType.keyType)
- vconv = _create_converter(dataType.valueType)
- return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
-
- elif isinstance(dataType, NullType):
- return lambda x: None
-
- elif not isinstance(dataType, StructType):
- return lambda x: x
-
- # dataType must be StructType
- names = [f.name for f in dataType.fields]
- converters = [_create_converter(f.dataType) for f in dataType.fields]
-
- def convert_struct(obj):
- if obj is None:
- return
-
- if isinstance(obj, tuple):
- if hasattr(obj, "_fields"):
- d = dict(zip(obj._fields, obj))
- elif hasattr(obj, "__FIELDS__"):
- d = dict(zip(obj.__FIELDS__, obj))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
- d = dict(obj)
- else:
- raise ValueError("unexpected tuple: %s" % str(obj))
-
- elif isinstance(obj, dict):
- d = obj
- elif hasattr(obj, "__dict__"): # object
- d = obj.__dict__
- else:
- raise ValueError("Unexpected obj: %s" % obj)
-
- return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
-
- return convert_struct
-
-
-_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-
-
-def _split_schema_abstract(s):
- """
- split the schema abstract into fields
-
- >>> _split_schema_abstract("a b c")
- ['a', 'b', 'c']
- >>> _split_schema_abstract("a(a b)")
- ['a(a b)']
- >>> _split_schema_abstract("a b[] c{a b}")
- ['a', 'b[]', 'c{a b}']
- >>> _split_schema_abstract(" ")
- []
- """
-
- r = []
- w = ''
- brackets = []
- for c in s:
- if c == ' ' and not brackets:
- if w:
- r.append(w)
- w = ''
- else:
- w += c
- if c in _BRACKETS:
- brackets.append(c)
- elif c in _BRACKETS.values():
- if not brackets or c != _BRACKETS[brackets.pop()]:
- raise ValueError("unexpected " + c)
-
- if brackets:
- raise ValueError("brackets not closed: %s" % brackets)
- if w:
- r.append(w)
- return r
-
-
-def _parse_field_abstract(s):
- """
- Parse a field in schema abstract
-
- >>> _parse_field_abstract("a")
- StructField(a,None,true)
- >>> _parse_field_abstract("b(c d)")
- StructField(b,StructType(...c,None,true),StructField(d...
- >>> _parse_field_abstract("a[]")
- StructField(a,ArrayType(None,true),true)
- >>> _parse_field_abstract("a{[]}")
- StructField(a,MapType(None,ArrayType(None,true),true),true)
- """
- if set(_BRACKETS.keys()) & set(s):
- idx = min((s.index(c) for c in _BRACKETS if c in s))
- name = s[:idx]
- return StructField(name, _parse_schema_abstract(s[idx:]), True)
- else:
- return StructField(s, None, True)
-
-
-def _parse_schema_abstract(s):
- """
- parse abstract into schema
-
- >>> _parse_schema_abstract("a b c")
- StructType...a...b...c...
- >>> _parse_schema_abstract("a[b c] b{}")
- StructType...a,ArrayType...b...c...b,MapType...
- >>> _parse_schema_abstract("c{} d{a b}")
- StructType...c,MapType...d,MapType...a...b...
- >>> _parse_schema_abstract("a b(t)").fields[1]
- StructField(b,StructType(List(StructField(t,None,true))),true)
- """
- s = s.strip()
- if not s:
- return
-
- elif s.startswith('('):
- return _parse_schema_abstract(s[1:-1])
-
- elif s.startswith('['):
- return ArrayType(_parse_schema_abstract(s[1:-1]), True)
-
- elif s.startswith('{'):
- return MapType(None, _parse_schema_abstract(s[1:-1]))
-
- parts = _split_schema_abstract(s)
- fields = [_parse_field_abstract(p) for p in parts]
- return StructType(fields)
-
-
-def _infer_schema_type(obj, dataType):
- """
- Fill the dataType with types inferred from obj
-
- >>> schema = _parse_schema_abstract("a b c d")
- >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
- >>> _infer_schema_type(row, schema)
- StructType...IntegerType...DoubleType...StringType...DateType...
- >>> row = [[1], {"key": (1, 2.0)}]
- >>> schema = _parse_schema_abstract("a[] b{c d}")
- >>> _infer_schema_type(row, schema)
- StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
- """
- if dataType is None:
- return _infer_type(obj)
-
- if not obj:
- return NullType()
-
- if isinstance(dataType, ArrayType):
- eType = _infer_schema_type(obj[0], dataType.elementType)
- return ArrayType(eType, True)
-
- elif isinstance(dataType, MapType):
- k, v = obj.iteritems().next()
- return MapType(_infer_schema_type(k, dataType.keyType),
- _infer_schema_type(v, dataType.valueType))
-
- elif isinstance(dataType, StructType):
- fs = dataType.fields
- assert len(fs) == len(obj), \
- "Obj(%s) have different length with fields(%s)" % (obj, fs)
- fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
- for o, f in zip(obj, fs)]
- return StructType(fields)
-
- else:
- raise ValueError("Unexpected dataType: %s" % dataType)
-
-
-_acceptable_types = {
- BooleanType: (bool,),
- ByteType: (int, long),
- ShortType: (int, long),
- IntegerType: (int, long),
- LongType: (int, long),
- FloatType: (float,),
- DoubleType: (float,),
- DecimalType: (decimal.Decimal,),
- StringType: (str, unicode),
- BinaryType: (bytearray,),
- DateType: (datetime.date,),
- TimestampType: (datetime.datetime,),
- ArrayType: (list, tuple, array),
- MapType: (dict,),
- StructType: (tuple, list),
-}
-
-
-def _verify_type(obj, dataType):
- """
- Verify the type of obj against dataType, raise an exception if
- they do not match.
-
- >>> _verify_type(None, StructType([]))
- >>> _verify_type("", StringType())
- >>> _verify_type(0, IntegerType())
- >>> _verify_type(range(3), ArrayType(ShortType()))
- >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- TypeError:...
- >>> _verify_type({}, MapType(StringType(), IntegerType()))
- >>> _verify_type((), StructType([]))
- >>> _verify_type([], StructType([]))
- >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
- >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
- >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
- """
- # all objects are nullable
- if obj is None:
- return
-
- if isinstance(dataType, UserDefinedType):
- if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
- raise ValueError("%r is not an instance of type %r" % (obj, dataType))
- _verify_type(dataType.serialize(obj), dataType.sqlType())
- return
-
- _type = type(dataType)
- assert _type in _acceptable_types, "unkown datatype: %s" % dataType
-
- # subclass of them can not be deserialized in JVM
- if type(obj) not in _acceptable_types[_type]:
- raise TypeError("%s can not accept object in type %s"
- % (dataType, type(obj)))
-
- if isinstance(dataType, ArrayType):
- for i in obj:
- _verify_type(i, dataType.elementType)
-
- elif isinstance(dataType, MapType):
- for k, v in obj.iteritems():
- _verify_type(k, dataType.keyType)
- _verify_type(v, dataType.valueType)
-
- elif isinstance(dataType, StructType):
- if len(obj) != len(dataType.fields):
- raise ValueError("Length of object (%d) does not match with"
- "length of fields (%d)" % (len(obj), len(dataType.fields)))
- for v, f in zip(obj, dataType.fields):
- _verify_type(v, f.dataType)
-
-
-_cached_cls = {}
-
-
-def _restore_object(dataType, obj):
- """ Restore object during unpickling. """
- # use id(dataType) as key to speed up lookup in dict
- # Because of batched pickling, dataType will be the
- # same object in most cases.
- k = id(dataType)
- cls = _cached_cls.get(k)
- if cls is None:
- # use dataType as key to avoid create multiple class
- cls = _cached_cls.get(dataType)
- if cls is None:
- cls = _create_cls(dataType)
- _cached_cls[dataType] = cls
- _cached_cls[k] = cls
- return cls(obj)
-
-
-def _create_object(cls, v):
- """ Create an customized object with class `cls`. """
- # datetime.date would be deserialized as datetime.datetime
- # from java type, so we need to set it back.
- if cls is datetime.date and isinstance(v, datetime.datetime):
- return v.date()
- return cls(v) if v is not None else v
-
-
-def _create_getter(dt, i):
- """ Create a getter for item `i` with schema """
- cls = _create_cls(dt)
-
- def getter(self):
- return _create_object(cls, self[i])
-
- return getter
-
-
-def _has_struct_or_date(dt):
- """Return whether `dt` is or has StructType/DateType in it"""
- if isinstance(dt, StructType):
- return True
- elif isinstance(dt, ArrayType):
- return _has_struct_or_date(dt.elementType)
- elif isinstance(dt, MapType):
- return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
- elif isinstance(dt, DateType):
- return True
- elif isinstance(dt, UserDefinedType):
- return True
- return False
-
-
-def _create_properties(fields):
- """Create properties according to fields"""
- ps = {}
- for i, f in enumerate(fields):
- name = f.name
- if (name.startswith("__") and name.endswith("__")
- or keyword.iskeyword(name)):
- warnings.warn("field name %s can not be accessed in Python,"
- "use position to access it instead" % name)
- if _has_struct_or_date(f.dataType):
- # delay creating object until accessing it
- getter = _create_getter(f.dataType, i)
- else:
- getter = itemgetter(i)
- ps[name] = property(getter)
- return ps
-
-
-def _create_cls(dataType):
- """
- Create an class by dataType
-
- The created class is similar to namedtuple, but can have nested schema.
-
- >>> schema = _parse_schema_abstract("a b c")
- >>> row = (1, 1.0, "str")
- >>> schema = _infer_schema_type(row, schema)
- >>> obj = _create_cls(schema)(row)
- >>> import pickle
- >>> pickle.loads(pickle.dumps(obj))
- Row(a=1, b=1.0, c='str')
-
- >>> row = [[1], {"key": (1, 2.0)}]
- >>> schema = _parse_schema_abstract("a[] b{c d}")
- >>> schema = _infer_schema_type(row, schema)
- >>> obj = _create_cls(schema)(row)
- >>> pickle.loads(pickle.dumps(obj))
- Row(a=[1], b={'key': Row(c=1, d=2.0)})
- >>> pickle.loads(pickle.dumps(obj.a))
- [1]
- >>> pickle.loads(pickle.dumps(obj.b))
- {'key': Row(c=1, d=2.0)}
- """
-
- if isinstance(dataType, ArrayType):
- cls = _create_cls(dataType.elementType)
-
- def List(l):
- if l is None:
- return
- return [_create_object(cls, v) for v in l]
-
- return List
-
- elif isinstance(dataType, MapType):
- kcls = _create_cls(dataType.keyType)
- vcls = _create_cls(dataType.valueType)
-
- def Dict(d):
- if d is None:
- return
- return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
-
- return Dict
-
- elif isinstance(dataType, DateType):
- return datetime.date
-
- elif isinstance(dataType, UserDefinedType):
- return lambda datum: dataType.deserialize(datum)
-
- elif not isinstance(dataType, StructType):
- # no wrapper for primitive types
- return lambda x: x
-
- class Row(tuple):
-
- """ Row in DataFrame """
- __DATATYPE__ = dataType
- __FIELDS__ = tuple(f.name for f in dataType.fields)
- __slots__ = ()
-
- # create property for fast access
- locals().update(_create_properties(dataType.fields))
-
- def asDict(self):
- """ Return as a dict """
- return dict((n, getattr(self, n)) for n in self.__FIELDS__)
-
- def __repr__(self):
- # call collect __repr__ for nested objects
- return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
- for n in self.__FIELDS__))
-
- def __reduce__(self):
- return (_restore_object, (self.__DATATYPE__, tuple(self)))
-
- return Row
-
-
-class SQLContext(object):
-
- """Main entry point for Spark SQL functionality.
-
- A SQLContext can be used create L{DataFrame}, register L{DataFrame} as
- tables, execute SQL over tables, cache tables, and read parquet files.
- """
-
- def __init__(self, sparkContext, sqlContext=None):
- """Create a new SQLContext.
-
- :param sparkContext: The SparkContext to wrap.
- :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new
- SQLContext in the JVM, instead we make all calls to this object.
-
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- TypeError:...
-
- >>> bad_rdd = sc.parallelize([1,2,3])
- >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
-
- >>> from datetime import datetime
- >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
- ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
- ... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> df = sqlCtx.inferSchema(allTypes)
- >>> df.registerTempTable("allTypes")
- >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
- ... 'from allTypes where b and i > 0').collect()
- [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
- >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
- ... x.row.a, x.list)).collect()
- [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
- """
- self._sc = sparkContext
- self._jsc = self._sc._jsc
- self._jvm = self._sc._jvm
- self._scala_SQLContext = sqlContext
-
- @property
- def _ssql_ctx(self):
- """Accessor for the JVM Spark SQL context.
-
- Subclasses can override this property to provide their own
- JVM Contexts.
- """
- if self._scala_SQLContext is None:
- self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
- return self._scala_SQLContext
-
- def registerFunction(self, name, f, returnType=StringType()):
- """Registers a lambda function as a UDF so it can be used in SQL statements.
-
- In addition to a name and the function itself, the return type can be optionally specified.
- When the return type is not given it default to a string and conversion will automatically
- be done. For any other return type, the produced object must match the specified type.
-
- >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
- >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
- [Row(c0=u'4')]
- >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
- >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
- [Row(c0=4)]
- """
- func = lambda _, it: imap(lambda x: f(*x), it)
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
- pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
- self._ssql_ctx.udf().registerPython(name,
- bytearray(pickled_cmd),
- env,
- includes,
- self._sc.pythonExec,
- bvars,
- self._sc._javaAccumulator,
- returnType.json())
-
- def inferSchema(self, rdd, samplingRatio=None):
- """Infer and apply a schema to an RDD of L{Row}.
-
- When samplingRatio is specified, the schema is inferred by looking
- at the types of each row in the sampled dataset. Otherwise, the
- first 100 rows of the RDD are inspected. Nested collections are
- supported, which can include array, dict, list, Row, tuple,
- namedtuple, or object.
-
- Each row could be L{pyspark.sql.Row} object or namedtuple or objects.
- Using top level dicts is deprecated, as dict is used to represent Maps.
-
- If a single column has multiple distinct inferred types, it may cause
- runtime exceptions.
-
- >>> rdd = sc.parallelize(
- ... [Row(field1=1, field2="row1"),
- ... Row(field1=2, field2="row2"),
- ... Row(field1=3, field2="row3")])
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.collect()[0]
- Row(field1=1, field2=u'row1')
-
- >>> NestedRow = Row("f1", "f2")
- >>> nestedRdd1 = sc.parallelize([
- ... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
- ... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> df = sqlCtx.inferSchema(nestedRdd1)
- >>> df.collect()
- [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
-
- >>> nestedRdd2 = sc.parallelize([
- ... NestedRow([[1, 2], [2, 3]], [1, 2]),
- ... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> df = sqlCtx.inferSchema(nestedRdd2)
- >>> df.collect()
- [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
-
- >>> from collections import namedtuple
- >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
- >>> rdd = sc.parallelize(
- ... [CustomRow(field1=1, field2="row1"),
- ... CustomRow(field1=2, field2="row2"),
- ... CustomRow(field1=3, field2="row3")])
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.collect()[0]
- Row(field1=1, field2=u'row1')
- """
-
- if isinstance(rdd, DataFrame):
- raise TypeError("Cannot apply schema to DataFrame")
-
- first = rdd.first()
- if not first:
- raise ValueError("The first row in RDD is empty, "
- "can not infer schema")
- if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated,"
- "please use pyspark.sql.Row instead")
-
- if samplingRatio is None:
- schema = _infer_schema(first)
- if _has_nulltype(schema):
- for row in rdd.take(100)[1:]:
- schema = _merge_type(schema, _infer_schema(row))
- if not _has_nulltype(schema):
- break
- else:
- warnings.warn("Some of types cannot be determined by the "
- "first 100 rows, please try again with sampling")
- else:
- if samplingRatio > 0.99:
- rdd = rdd.sample(False, float(samplingRatio))
- schema = rdd.map(_infer_schema).reduce(_merge_type)
-
- converter = _create_converter(schema)
- rdd = rdd.map(converter)
- return self.applySchema(rdd, schema)
-
- def applySchema(self, rdd, schema):
- """
- Applies the given schema to the given RDD of L{tuple} or L{list}.
-
- These tuples or lists can contain complex nested structures like
- lists, maps or nested rows.
-
- The schema should be a StructType.
-
- It is important that the schema matches the types of the objects
- in each row or exceptions could be thrown at runtime.
-
- >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
- >>> schema = StructType([StructField("field1", IntegerType(), False),
- ... StructField("field2", StringType(), False)])
- >>> df = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
- >>> df2 = sqlCtx.sql("SELECT * from table1")
- >>> df2.collect()
- [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
-
- >>> from datetime import date, datetime
- >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
- ... date(2010, 1, 1),
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3], None)])
- >>> schema = StructType([
- ... StructField("byte1", ByteType(), False),
- ... StructField("byte2", ByteType(), False),
- ... StructField("short1", ShortType(), False),
- ... StructField("short2", ShortType(), False),
- ... StructField("int", IntegerType(), False),
- ... StructField("float", FloatType(), False),
- ... StructField("date", DateType(), False),
- ... StructField("time", TimestampType(), False),
- ... StructField("map",
- ... MapType(StringType(), IntegerType(), False), False),
- ... StructField("struct",
- ... StructType([StructField("b", ShortType(), False)]), False),
- ... StructField("list", ArrayType(ByteType(), False), False),
- ... StructField("null", DoubleType(), True)])
- >>> df = sqlCtx.applySchema(rdd, schema)
- >>> results = df.map(
- ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
- ... x.time, x.map["a"], x.struct.b, x.list, x.null))
- >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
- (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
- datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
-
- >>> df.registerTempTable("table2")
- >>> sqlCtx.sql(
- ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
- ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
- ... "float + 1.5 as float FROM table2").collect()
- [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
-
- >>> rdd = sc.parallelize([(127, -32768, 1.0,
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3])])
- >>> abstract = "byte short float time map{} struct(b) list[]"
- >>> schema = _parse_schema_abstract(abstract)
- >>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> df = sqlCtx.applySchema(rdd, typedSchema)
- >>> df.collect()
- [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
- """
-
- if isinstance(rdd, DataFrame):
- raise TypeError("Cannot apply schema to DataFrame")
-
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
-
- # take the first few rows to verify schema
- rows = rdd.take(10)
- # Row() cannot been deserialized by Pyrolite
- if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
- rdd = rdd.map(tuple)
- rows = rdd.take(10)
-
- for row in rows:
- _verify_type(row, schema)
-
- # convert python objects to sql data
- converter = _python_to_sql_converter(schema)
- rdd = rdd.map(converter)
-
- jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- return DataFrame(df, self)
-
- def registerRDDAsTable(self, rdd, tableName):
- """Registers the given RDD as a temporary table in the catalog.
-
- Temporary tables exist only during the lifetime of this instance of
- SQLContext.
-
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
- """
- if (rdd.__class__ is DataFrame):
- df = rdd._jdf
- self._ssql_ctx.registerRDDAsTable(df, tableName)
- else:
- raise ValueError("Can only register DataFrame as table")
-
- def parquetFile(self, *paths):
- """Loads a Parquet file, returning the result as a L{DataFrame}.
-
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
- """
- gateway = self._sc._gateway
- jpath = paths[0]
- jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1)
- for i in range(1, len(paths)):
- jpaths[i] = paths[i]
- jdf = self._ssql_ctx.parquetFile(jpath, jpaths)
- return DataFrame(jdf, self)
-
- def jsonFile(self, path, schema=None, samplingRatio=1.0):
- """
- Loads a text file storing one JSON object per line as a
- L{DataFrame}.
-
- If the schema is provided, applies the given schema to this
- JSON dataset.
-
- Otherwise, it samples the dataset with ratio `samplingRatio` to
- determine the schema.
-
- >>> import tempfile, shutil
- >>> jsonFile = tempfile.mkdtemp()
- >>> shutil.rmtree(jsonFile)
- >>> ofn = open(jsonFile, 'w')
- >>> for json in jsonStrings:
- ... print>>ofn, json
- >>> ofn.close()
- >>> df1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
- >>> sqlCtx.registerRDDAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> schema = StructType([
- ... StructField("field2", StringType(), True),
- ... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerRDDAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
- """
- if schema is None:
- df = self._ssql_ctx.jsonFile(path, samplingRatio)
- else:
- scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- df = self._ssql_ctx.jsonFile(path, scala_datatype)
- return DataFrame(df, self)
-
- def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
- """Loads an RDD storing one JSON object per string as a L{DataFrame}.
-
- If the schema is provided, applies the given schema to this
- JSON dataset.
-
- Otherwise, it samples the dataset with ratio `samplingRatio` to
- determine the schema.
-
- >>> df1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
- >>> sqlCtx.registerRDDAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> schema = StructType([
- ... StructField("field2", StringType(), True),
- ... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerRDDAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
-
- >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
- >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
- """
-
- def func(iterator):
- for x in iterator:
- if not isinstance(x, basestring):
- x = unicode(x)
- if isinstance(x, unicode):
- x = x.encode("utf-8")
- yield x
- keyed = rdd.mapPartitions(func)
- keyed._bypass_serializer = True
- jrdd = keyed._jrdd.map(self._jvm.BytesToString())
- if schema is None:
- df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
- else:
- scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
- return DataFrame(df, self)
-
- def sql(self, sqlQuery):
- """Return a L{DataFrame} representing the result of the given query.
-
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
- >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
- >>> df2.collect()
- [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
- """
- return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
-
- def table(self, tableName):
- """Returns the specified table as a L{DataFrame}.
-
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
- >>> df2 = sqlCtx.table("table1")
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
- """
- return DataFrame(self._ssql_ctx.table(tableName), self)
-
- def cacheTable(self, tableName):
- """Caches the specified table in-memory."""
- self._ssql_ctx.cacheTable(tableName)
-
- def uncacheTable(self, tableName):
- """Removes the specified table from the in-memory cache."""
- self._ssql_ctx.uncacheTable(tableName)
-
-
-class HiveContext(SQLContext):
-
- """A variant of Spark SQL that integrates with data stored in Hive.
-
- Configuration for Hive is read from hive-site.xml on the classpath.
- It supports running both SQL and HiveQL commands.
- """
-
- def __init__(self, sparkContext, hiveContext=None):
- """Create a new HiveContext.
-
- :param sparkContext: The SparkContext to wrap.
- :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
- HiveContext in the JVM, instead we make all calls to this object.
- """
- SQLContext.__init__(self, sparkContext)
-
- if hiveContext:
- self._scala_HiveContext = hiveContext
-
- @property
- def _ssql_ctx(self):
- try:
- if not hasattr(self, '_scala_HiveContext'):
- self._scala_HiveContext = self._get_hive_ctx()
- return self._scala_HiveContext
- except Py4JError as e:
- raise Exception("You must build Spark with Hive. "
- "Export 'SPARK_HIVE=true' and run "
- "build/sbt assembly", e)
-
- def _get_hive_ctx(self):
- return self._jvm.HiveContext(self._jsc.sc())
-
-
-def _create_row(fields, values):
- row = Row(*values)
- row.__FIELDS__ = fields
- return row
-
-
-class Row(tuple):
-
- """
- A row in L{DataFrame}. The fields in it can be accessed like attributes.
-
- Row can be used to create a row object by using named arguments;
- the fields will be sorted by name.
-
- >>> row = Row(name="Alice", age=11)
- >>> row
- Row(age=11, name='Alice')
- >>> row.name, row.age
- ('Alice', 11)
-
- Row can also be used to create another Row-like class, which can
- then be used to create Row objects, such as
-
- >>> Person = Row("name", "age")
- >>> Person
- <Row(name, age)>
- >>> Person("Alice", 11)
- Row(name='Alice', age=11)
- """
-
- def __new__(self, *args, **kwargs):
- if args and kwargs:
- raise ValueError("Can not use both args "
- "and kwargs to create Row")
- if args:
- # create row class or objects
- return tuple.__new__(self, args)
-
- elif kwargs:
- # create row objects
- names = sorted(kwargs.keys())
- values = tuple(kwargs[n] for n in names)
- row = tuple.__new__(self, values)
- row.__FIELDS__ = names
- return row
-
- else:
- raise ValueError("No args or kwargs")
-
- def asDict(self):
- """
- Return as a dict
- """
- if not hasattr(self, "__FIELDS__"):
- raise TypeError("Cannot convert a Row class into dict")
- return dict(zip(self.__FIELDS__, self))
-
- # let the object act like a class
- def __call__(self, *args):
- """create new Row object"""
- return _create_row(self, args)
-
- def __getattr__(self, item):
- if item.startswith("__"):
- raise AttributeError(item)
- try:
- # it will be slow when it has many fields,
- # but this will not be used in normal cases
- idx = self.__FIELDS__.index(item)
- return self[idx]
- except ValueError:
- raise AttributeError(item)
-
- def __reduce__(self):
- if hasattr(self, "__FIELDS__"):
- return (_create_row, (self.__FIELDS__, tuple(self)))
- else:
- return tuple.__reduce__(self)
-
- def __repr__(self):
- if hasattr(self, "__FIELDS__"):
- return "Row(%s)" % ", ".join("%s=%r" % (k, v)
- for k, v in zip(self.__FIELDS__, self))
- else:
- return "<Row(%s)>" % ", ".join(self)
-
-
-class DataFrame(object):
-
- """A collection of rows that have the same columns.
-
- A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
- and can be created using various functions in :class:`SQLContext`::
-
- people = sqlContext.parquetFile("...")
-
- Once created, it can be manipulated using the various domain-specific-language
- (DSL) functions defined in :class:`DataFrame` and :class:`Column`.
-
- To select a column from the data frame, use the apply method::
-
- ageCol = people.age
-
- Note that the :class:`Column` type can also be manipulated
- through its various functions::
-
- # The following creates a new column that increases everybody's age by 10.
- people.age + 10
-
-
- A more concrete example::
-
- # To create DataFrame using SQLContext
- people = sqlContext.parquetFile("...")
- department = sqlContext.parquetFile("...")
-
- people.filter(people.age > 30).join(department, people.deptId == department.id) \
- .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
- """
-
- def __init__(self, jdf, sql_ctx):
- self._jdf = jdf
- self.sql_ctx = sql_ctx
- self._sc = sql_ctx and sql_ctx._sc
- self.is_cached = False
-
- @property
- def rdd(self):
- """
- Return the content of the :class:`DataFrame` as an :class:`RDD`
- of :class:`Row` s.
- """
- if not hasattr(self, '_lazy_rdd'):
- jrdd = self._jdf.javaToPython()
- rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
- schema = self.schema()
-
- def applySchema(it):
- cls = _create_cls(schema)
- return itertools.imap(cls, it)
-
- self._lazy_rdd = rdd.mapPartitions(applySchema)
-
- return self._lazy_rdd
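For illustration, the RDD returned here contains plain Row objects, so ordinary RDD transformations apply; with the same two-row `df` used in the surrounding doctests this sketch would be expected to give:

    >>> df.rdd.map(lambda r: r.name).collect()
    [u'Alice', u'Bob']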
-
- def toJSON(self, use_unicode=False):
- """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
-
- >>> df1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql( "SELECT * from table1")
- >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
- True
- >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1")
- >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
- True
- """
- rdd = self._jdf.toJSON()
- return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
-
- def saveAsParquetFile(self, path):
- """Save the contents as a Parquet file, preserving the schema.
-
- Files that are written out using this method can be read back in as
- a DataFrame using the L{SQLContext.parquetFile} method.
-
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(df2.collect()) == sorted(df.collect())
- True
- """
- self._jdf.saveAsParquetFile(path)
-
- def registerTempTable(self, name):
- """Registers this RDD as a temporary table using the given name.
-
- The lifetime of this temporary table is tied to the L{SQLContext}
- that was used to create this DataFrame.
-
- >>> df.registerTempTable("people")
- >>> df2 = sqlCtx.sql("select * from people")
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
- """
- self._jdf.registerTempTable(name)
-
- def registerAsTable(self, name):
- """DEPRECATED: use registerTempTable() instead"""
- warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
- self.registerTempTable(name)
-
- def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this DataFrame into the specified table.
-
- Optionally overwriting any existing data.
- """
- self._jdf.insertInto(tableName, overwrite)
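A hedged usage sketch, assuming a table named "people" already exists in the catalog (a hypothetical name used only for illustration):

    >>> df.insertInto("people", overwrite=True)  # doctest: +SKIP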
-
- def saveAsTable(self, tableName):
- """Creates a new table with the contents of this DataFrame."""
- self._jdf.saveAsTable(tableName)
-
- def schema(self):
- """Returns the schema of this DataFrame (represented by
- a L{StructType}).
-
- >>> df.schema()
- StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
- """
- return _parse_datatype_json_string(self._jdf.schema().json())
-
- def printSchema(self):
- """Prints out the schema in the tree format.
-
- >>> df.printSchema()
- root
- |-- age: integer (nullable = true)
- |-- name: string (nullable = true)
-
- """
- print (self._jdf.schema().treeString())
-
- def count(self):
- """Return the number of elements in this RDD.
-
- Unlike the base RDD implementation of count, this implementation
- leverages the query optimizer to compute the count on the DataFrame,
- which supports features such as filter pushdown.
-
- >>> df.count()
- 2L
- """
- return self._jdf.count()
-
- def collect(self):
- """Return a list that contains all of the rows.
-
- Each object in the list is a Row, the fields can be accessed as
- attributes.
-
- >>> df.collect()
- [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
- """
- with SCCallSiteSync(self._sc) as css:
- bytesInJava = self._jdf.javaToPython().collect().iterator()
- tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
- tempFile.close()
- self._sc._writeToFile(bytesInJava, tempFile.name)
- # Read the data into Python and deserialize it:
- with open(tempFile.name, 'rb') as tempFile:
- rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
- os.unlink(tempFile.name)
- cls = _create_cls(self.schema())
- return [cls(r) for r in rs]
-
- def limit(self, num):
- """Limit the result count to the number specified.
-
- >>> df.limit(1).collect()
- [Row(age=2, name=u'Alice')]
- >>> df.limit(0).collect()
- []
- """
- jdf = self._jdf.limit(num)
- return DataFrame(jdf, self.sql_ctx)
-
- def take(self, num):
- """Take the first num rows of the RDD.
-
- Each object in the list is a Row, the fields can be accessed as
- attributes.
-
- >>> df.take(2)
- [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
- """
- return self.limit(num).collect()
-
- def map(self, f):
- """ Return a new RDD by applying a function to each Row; it's a
- shorthand for df.rdd.map()
-
- >>> df.map(lambda p: p.name).collect()
- [u'Alice', u'Bob']
- """
- return self.rdd.map(f)
-
- def mapPartitions(self, f, preservesPartitioning=False):
- """
- Return a new RDD by applying a function to each partition.
-
- >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
- >>> def f(iterator): yield 1
- >>> rdd.mapPartitions(f).sum()
- 4
- """
- return self.rdd.mapPartitions(f, preservesPartitioning)
-
- def cache(self):
- """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
- """
- self.is_cached = True
- self._jdf.cache()
- return self
-
- def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
- """ Set the storage level to persist its values across operations
- after the first time it is computed. This can only be used to assign
- a new storage level if the RDD does not have a storage level set yet.
- If no storage level is specified, it defaults to C{MEMORY_ONLY_SER}.
- """
- self.is_cached = True
- javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
- self._jdf.persist(javaStorageLevel)
- return self
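For example, a minimal sketch that picks a non-default level; any level exposed by pyspark's StorageLevel should be acceptable here:

    >>> from pyspark import StorageLevel
    >>> df.persist(StorageLevel.MEMORY_AND_DISK)  # doctest: +SKIP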
-
- def unpersist(self, blocking=True):
- """ Mark it as non-persistent, and remove all blocks for it from
- memory and disk.
- """
- self.is_cached = False
- self._jdf.unpersist(blocking)
- return self
-
- # def coalesce(self, numPartitions, shuffle=False):
- # rdd = self._jdf.coalesce(numPartitions, shuffle, None)
- # return DataFrame(rdd, self.sql_ctx)
-
- def repartition(self, numPartitions):
- """ Return a new :class:`DataFrame` that has exactly `numPartitions`
- partitions.
- """
- rdd = self._jdf.repartition(numPartitions, None)
- return DataFrame(rdd, self.sql_ctx)
-
- def sample(self, withReplacement, fraction, seed=None):
- """
- Return a sampled subset of this DataFrame.
-
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.sample(False, 0.5, 97).count()
- 2L
- """
- assert fraction >= 0.0, "Negative fraction value: %s" % fraction
- seed = seed if seed is not None else random.randint(0, sys.maxint)
- rdd = self._jdf.sample(withReplacement, fraction, long(seed))
- return DataFrame(rdd, self.sql_ctx)
-
- # def takeSample(self, withReplacement, num, seed=None):
- # """Return a fixed-size sampled subset of this DataFrame.
- #
- # >>> df = sqlCtx.inferSchema(rdd)
- # >>> df.takeSample(False, 2, 97)
- # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
- # """
- # seed = seed if seed is not None else random.randint(0, sys.maxint)
- # with SCCallSiteSync(self.context) as css:
- # bytesInJava = self._jdf \
- # .takeSampleToPython(withReplacement, num, long(seed)) \
- # .iterator()
- # cls = _create_cls(self.schema())
- # return map(cls, self._collect_iterator_through_file(bytesInJava))
-
- @property
- def dtypes(self):
- """Return all column names and their data types as a list.
-
- >>> df.dtypes
- [('age', 'integer'), ('name', 'string')]
- """
- return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
-
- @property
- def columns(self):
- """ Return all column names as a list.
-
- >>> df.columns
- [u'age', u'name']
- """
- return [f.name for f in self.schema().fields]
-
- def join(self, other, joinExprs=None, joinType=None):
- """
- Join with another DataFrame, using the given join expression.
- The following performs a full outer join between `df` and `df2`:
-
- :param other: Right side of the join
- :param joinExprs: Join expression
- :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
-
- >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
- [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
- """
-
- if joinExprs is None:
- jdf = self._jdf.join(other._jdf)
- else:
- assert isinstance(joinExprs, Column), "joinExprs should be Column"
- if joinType is None:
- jdf = self._jdf.join(other._jdf, joinExprs._jc)
- else:
- assert isinstance(joinType, basestring), "joinType should be basestring"
- jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
- return DataFrame(jdf, self.sql_ctx)
-
- def sort(self, *cols):
- """ Return a new :class:`DataFrame` sorted by the specified column.
-
- :param cols: The columns or expressions used for sorting
-
- >>> df.sort(df.age.desc()).collect()
- [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
- >>> df.sortBy(df.age.desc()).collect()
- [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
- """
- if not cols:
- raise ValueError("should sort by at least one column")
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- self._sc._gateway._gateway_client)
- jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
- return DataFrame(jdf, self.sql_ctx)
-
- sortBy = sort
-
- def head(self, n=None):
- """ Return the first `n` rows or the first row if n is None.
-
- >>> df.head()
- Row(age=2, name=u'Alice')
- >>> df.head(1)
- [Row(age=2, name=u'Alice')]
- """
- if n is None:
- rs = self.head(1)
- return rs[0] if rs else None
- return self.take(n)
-
- def first(self):
- """ Return the first row.
-
- >>> df.first()
- Row(age=2, name=u'Alice')
- """
- return self.head()
-
- def __getitem__(self, item):
- """ Return the column by given name
-
- >>> df['age'].collect()
- [Row(age=2), Row(age=5)]
- >>> df[ ["name", "age"]].collect()
- [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
- >>> df[ df.age > 3 ].collect()
- [Row(age=5, name=u'Bob')]
- """
- if isinstance(item, basestring):
- jc = self._jdf.apply(item)
- return Column(jc, self.sql_ctx)
- elif isinstance(item, Column):
- return self.filter(item)
- elif isinstance(item, list):
- return self.select(*item)
- else:
- raise IndexError("unexpected index: %s" % item)
-
- def __getattr__(self, name):
- """ Return the column by given name
-
- >>> df.age.collect()
- [Row(age=2), Row(age=5)]
- """
- if name.startswith("__"):
- raise AttributeError(name)
- jc = self._jdf.apply(name)
- return Column(jc, self.sql_ctx)
-
- def select(self, *cols):
- """ Selecting a set of expressions.
-
- >>> df.select().collect()
- [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
- >>> df.select('*').collect()
- [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
- >>> df.select('name', 'age').collect()
- [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
- >>> df.select(df.name, (df.age + 10).alias('age')).collect()
- [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
- """
- if not cols:
- cols = ["*"]
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- self._sc._gateway._gateway_client)
- jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
- return DataFrame(jdf, self.sql_ctx)
-
- def selectExpr(self, *expr):
- """
- Selects a set of SQL expressions. This is a variant of
- `select` that accepts SQL expressions.
-
- >>> df.selectExpr("age * 2", "abs(age)").collect()
- [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
- """
- jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
- jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
- return DataFrame(jdf, self.sql_ctx)
-
- def filter(self, condition):
- """ Filter rows using the given condition, which can be either a
- Column expression or an SQL expression string.
-
- where() is an alias for filter().
-
- >>> df.filter(df.age > 3).collect()
- [Row(age=5, name=u'Bob')]
- >>> df.where(df.age == 2).collect()
- [Row(age=2, name=u'Alice')]
-
- >>> df.filter("age > 3").collect()
- [Row(age=5, name=u'Bob')]
- >>> df.where("age = 2").collect()
- [Row(age=2, name=u'Alice')]
- """
- if isinstance(condition, basestring):
- jdf = self._jdf.filter(condition)
- elif isinstance(condition, Column):
- jdf = self._jdf.filter(condition._jc)
- else:
- raise TypeError("condition should be string or Column")
- return DataFrame(jdf, self.sql_ctx)
-
- where = filter
-
- def groupBy(self, *cols):
- """ Group the :class:`DataFrame` using the specified columns,
- so we can run aggregation on them. See :class:`GroupedData`
- for all the available aggregate functions.
-
- >>> df.groupBy().avg().collect()
- [Row(AVG(age#0)=3.5)]
- >>> df.groupBy('name').agg({'age': 'mean'}).collect()
- [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
- >>> df.groupBy(df.name).avg().collect()
- [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
- """
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- self._sc._gateway._gateway_client)
- jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
- return GroupedData(jdf, self.sql_ctx)
-
- def agg(self, *exprs):
- """ Aggregate on the entire :class:`DataFrame` without groups
- (shorthand for df.groupBy.agg()).
-
- >>> df.agg({"age": "max"}).collect()
- [Row(MAX(age#0)=5)]
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.min(df.age)).collect()
- [Row(MIN(age#0)=2)]
- """
- return self.groupBy().agg(*exprs)
-
- def unionAll(self, other):
- """ Return a new DataFrame containing union of rows in this
- frame and another frame.
-
- This is equivalent to `UNION ALL` in SQL.
- """
- return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
-
- def intersect(self, other):
- """ Return a new :class:`DataFrame` containing rows only in
- both this frame and another frame.
-
- This is equivalent to `INTERSECT` in SQL.
- """
- return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
-
- def subtract(self, other):
- """ Return a new :class:`DataFrame` containing rows in this frame
- but not in another frame.
-
- This is equivalent to `EXCEPT` in SQL.
- """
- return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
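A short sketch of the three set operations on the two-row `df` used throughout these doctests; the counts would be expected to come back as 4, 2 and 0 respectively:

    >>> df.unionAll(df).count()   # doctest: +SKIP
    >>> df.intersect(df).count()  # doctest: +SKIP
    >>> df.subtract(df).count()   # doctest: +SKIP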
-
- def addColumn(self, colName, col):
- """ Return a new :class:`DataFrame` by adding a column.
-
- >>> df.addColumn('age2', df.age + 2).collect()
- [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
- """
- return self.select('*', col.alias(colName))
-
- def to_pandas(self):
- """
- Collect all the rows and return a `pandas.DataFrame`.
-
- >>> df.to_pandas() # doctest: +SKIP
- age name
- 0 2 Alice
- 1 5 Bob
- """
- import pandas as pd
- return pd.DataFrame.from_records(self.collect(), columns=self.columns)
-
-
-# Having SchemaRDD for backward compatibility (for docs)
-class SchemaRDD(DataFrame):
- """
- SchemaRDD is deprecated, please use DataFrame
- """
-
-
-def dfapi(f):
- def _api(self):
- name = f.__name__
- jdf = getattr(self._jdf, name)()
- return DataFrame(jdf, self.sql_ctx)
- _api.__name__ = f.__name__
- _api.__doc__ = f.__doc__
- return _api
-
-
-class GroupedData(object):
-
- """
- A set of methods for aggregations on a :class:`DataFrame`,
- created by DataFrame.groupBy().
- """
-
- def __init__(self, jdf, sql_ctx):
- self._jdf = jdf
- self.sql_ctx = sql_ctx
-
- def agg(self, *exprs):
- """ Compute aggregates by specifying a map from column name
- to aggregate methods.
-
- The available aggregate methods are `avg`, `max`, `min`,
- `sum`, `count`.
-
- :param exprs: list of aggregate columns or a map from column
- name to aggregate methods.
-
- >>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"age": "max"}).collect()
- [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
- >>> from pyspark.sql import Dsl
- >>> gdf.agg(Dsl.min(df.age)).collect()
- [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
- """
- assert exprs, "exprs should not be empty"
- if len(exprs) == 1 and isinstance(exprs[0], dict):
- jmap = MapConverter().convert(exprs[0],
- self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.agg(jmap)
- else:
- # Columns
- assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
- jcols = ListConverter().convert([c._jc for c in exprs[1:]],
- self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
- return DataFrame(jdf, self.sql_ctx)
-
- @dfapi
- def count(self):
- """ Count the number of rows for each group.
-
- >>> df.groupBy(df.age).count().collect()
- [Row(age=2, count=1), Row(age=5, count=1)]
- """
-
- @dfapi
- def mean(self):
- """Compute the average value for each numeric column
- for each group. This is an alias for `avg`."""
-
- @dfapi
- def avg(self):
- """Compute the average value for each numeric column
- for each group."""
-
- @dfapi
- def max(self):
- """Compute the max value for each numeric column for
- each group. """
-
- @dfapi
- def min(self):
- """Compute the min value for each numeric column for
- each group."""
-
- @dfapi
- def sum(self):
- """Compute the sum for each numeric column for each
- group."""
-
-
-def _create_column_from_literal(literal):
- sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.lit(literal)
-
-
-def _create_column_from_name(name):
- sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.col(name)
-
-
-def _to_java_column(col):
- if isinstance(col, Column):
- jcol = col._jc
- else:
- jcol = _create_column_from_name(col)
- return jcol
-
-
-def _unary_op(name, doc="unary operator"):
- """ Create a method for given unary operator """
- def _(self):
- jc = getattr(self._jc, name)()
- return Column(jc, self.sql_ctx)
- _.__doc__ = doc
- return _
-
-
-def _dsl_op(name, doc=''):
- def _(self):
- jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
- return Column(jc, self.sql_ctx)
- _.__doc__ = doc
- return _
-
-
-def _bin_op(name, doc="binary operator"):
- """ Create a method for given binary operator
- """
- def _(self, other):
- jc = other._jc if isinstance(other, Column) else other
- njc = getattr(self._jc, name)(jc)
- return Column(njc, self.sql_ctx)
- _.__doc__ = doc
- return _
-
-
-def _reverse_op(name, doc="binary operator"):
- """ Create a method for binary operator (this object is on right side)
- """
- def _(self, other):
- jother = _create_column_from_literal(other)
- jc = getattr(jother, name)(self._jc)
- return Column(jc, self.sql_ctx)
- _.__doc__ = doc
- return _
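As a sketch of how these factories are used by the Column class below: `__add__ = _bin_op("plus")` makes `df.age + 10` call `plus(10)` on the underlying JVM column and wrap the result in a new Column, so with the two-row `df` from these doctests one would expect:

    >>> df.select((df.age + 10).alias('age')).collect()
    [Row(age=12), Row(age=15)]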
-
-
-class Column(DataFrame):
-
- """
- A column in a DataFrame.
-
- `Column` instances can be created by::
-
- # 1. Select a column out of a DataFrame
- df.colName
- df["colName"]
-
- # 2. Create from an expression
- df.colName + 1
- 1 / df.colName
- """
-
- def __init__(self, jc, sql_ctx=None):
- self._jc = jc
- super(Column, self).__init__(jc, sql_ctx)
-
- # arithmetic operators
- __neg__ = _dsl_op("negate")
- __add__ = _bin_op("plus")
- __sub__ = _bin_op("minus")
- __mul__ = _bin_op("multiply")
- __div__ = _bin_op("divide")
- __mod__ = _bin_op("mod")
- __radd__ = _bin_op("plus")
- __rsub__ = _reverse_op("minus")
- __rmul__ = _bin_op("multiply")
- __rdiv__ = _reverse_op("divide")
- __rmod__ = _reverse_op("mod")
-
- # comparison operators
- __eq__ = _bin_op("equalTo")
- __ne__ = _bin_op("notEqual")
- __lt__ = _bin_op("lt")
- __le__ = _bin_op("leq")
- __ge__ = _bin_op("geq")
- __gt__ = _bin_op("gt")
-
- # `and`, `or`, `not` cannot be overloaded in Python,
- # so use bitwise operators as boolean operators
- __and__ = _bin_op('and')
- __or__ = _bin_op('or')
- __invert__ = _dsl_op('not')
- __rand__ = _bin_op("and")
- __ror__ = _bin_op("or")
-
- # container operators
- __contains__ = _bin_op("contains")
- __getitem__ = _bin_op("getItem")
- getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
-
- # string methods
- rlike = _bin_op("rlike")
- like = _bin_op("like")
- startswith = _bin_op("startsWith")
- endswith = _bin_op("endsWith")
-
- def substr(self, startPos, length):
- """
- Return a Column which is a substring of the column
-
- :param startPos: start position (int or Column)
- :param length: length of the substring (int or Column)
-
- >>> df.name.substr(1, 3).collect()
- [Row(col=u'Ali'), Row(col=u'Bob')]
- """
- if type(startPos) != type(length):
- raise TypeError("Can not mix the type")
- if isinstance(startPos, (int, long)):
- jc = self._jc.substr(startPos, length)
- elif isinstance(startPos, Column):
- jc = self._jc.substr(startPos._jc, length._jc)
- else:
- raise TypeError("Unexpected type: %s" % type(startPos))
- return Column(jc, self.sql_ctx)
-
- __getslice__ = substr
-
- # order
- asc = _unary_op("asc")
- desc = _unary_op("desc")
-
- isNull = _unary_op("isNull", "True if the current expression is null.")
- isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
-
- def alias(self, alias):
- """Return an alias for this column
-
- >>> df.age.alias("age2").collect()
- [Row(age2=2), Row(age2=5)]
- """
- return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-
- def cast(self, dataType):
- """ Convert the column into type `dataType`
-
- >>> df.select(df.age.cast("string").alias('ages')).collect()
- [Row(ages=u'2'), Row(ages=u'5')]
- >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
- [Row(ages=u'2'), Row(ages=u'5')]
- """
- if self.sql_ctx is None:
- sc = SparkContext._active_spark_context
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- else:
- ssql_ctx = self.sql_ctx._ssql_ctx
- if isinstance(dataType, basestring):
- jc = self._jc.cast(dataType)
- elif isinstance(dataType, DataType):
- jdt = ssql_ctx.parseDataType(dataType.json())
- jc = self._jc.cast(jdt)
- return Column(jc, self.sql_ctx)
-
- def to_pandas(self):
- """
- Return a pandas.Series from the column
-
- >>> df.age.to_pandas() # doctest: +SKIP
- 0 2
- 1 5
- dtype: int64
- """
- import pandas as pd
- data = [c for c, in self.collect()]
- return pd.Series(data)
-
-
-def _aggregate_func(name, doc=""):
- """ Create a function for aggregator by name"""
- def _(col):
- sc = SparkContext._active_spark_context
- jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
- return Column(jc)
- _.__name__ = name
- _.__doc__ = doc
- return staticmethod(_)
-
-
-class UserDefinedFunction(object):
- def __init__(self, func, returnType):
- self.func = func
- self.returnType = returnType
- self._broadcast = None
- self._judf = self._create_judf()
-
- def _create_judf(self):
- f = self.func # put it in closure `func`
- func = lambda _, it: imap(lambda x: f(*x), it)
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
- sc = SparkContext._active_spark_context
- pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(self.returnType.json())
- judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
- includes, sc.pythonExec, broadcast_vars,
- sc._javaAccumulator, jdt)
- return judf
-
- def __del__(self):
- if self._broadcast is not None:
- self._broadcast.unpersist()
- self._broadcast = None
-
- def __call__(self, *cols):
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
-
-class Dsl(object):
- """
- A collections of builtin aggregators
- A collection of builtin functions and aggregators
- DSLS = {
- 'lit': 'Creates a :class:`Column` of literal value.',
- 'col': 'Returns a :class:`Column` based on the given column name.',
- 'column': 'Returns a :class:`Column` based on the given column name.',
- 'upper': 'Converts a string expression to upper case.',
- 'lower': 'Converts a string expression to lower case.',
- 'sqrt': 'Computes the square root of the specified float value.',
- 'abs': 'Computes the absolute value.',
-
- 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
- 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
- 'first': 'Aggregate function: returns the first value in a group.',
- 'last': 'Aggregate function: returns the last value in a group.',
- 'count': 'Aggregate function: returns the number of items in a group.',
- 'sum': 'Aggregate function: returns the sum of all values in the expression.',
- 'avg': 'Aggregate function: returns the average of the values in a group.',
- 'mean': 'Aggregate function: returns the average of the values in a group.',
- 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
- }
-
- for _name, _doc in DSLS.items():
- locals()[_name] = _aggregate_func(_name, _doc)
- del _name, _doc
-
- @staticmethod
- def countDistinct(col, *cols):
- """ Return a new Column for distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
- [Row(c=2)]
-
- >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
- sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
- @staticmethod
- def approxCountDistinct(col, rsd=None):
- """ Return a new Column for approximate distinct count of col
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- if rsd is None:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
- else:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
- return Column(jc)
-
- @staticmethod
- def udf(f, returnType=StringType()):
- """Create a user defined function (UDF)
-
- >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
- >>> df.select(slen(df.name).alias('slen')).collect()
- [Row(slen=5), Row(slen=3)]
- """
- return UserDefinedFunction(f, returnType)
-
-
-def _test():
- import doctest
- from pyspark.context import SparkContext
- # let doctest run in pyspark.sql, so DataTypes can be picklable
- import pyspark.sql
- from pyspark.sql import Row, SQLContext
- from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
- globs = pyspark.sql.__dict__.copy()
- sc = SparkContext('local[4]', 'PythonTest')
- globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- globs['rdd'] = sc.parallelize(
- [Row(field1=1, field2="row1"),
- Row(field1=2, field2="row2"),
- Row(field1=3, field2="row3")]
- )
- rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
- rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
- globs['df'] = sqlCtx.inferSchema(rdd2)
- globs['df2'] = sqlCtx.inferSchema(rdd3)
- globs['ExamplePoint'] = ExamplePoint
- globs['ExamplePointUDT'] = ExamplePointUDT
- jsonStrings = [
- '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
- '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
- '"field6":[{"field7": "row2"}]}',
- '{"field1" : null, "field2": "row3", '
- '"field3":{"field4":33, "field5": []}}'
- ]
- globs['jsonStrings'] = jsonStrings
- globs['json'] = sc.parallelize(jsonStrings)
- (failure_count, test_count) = doctest.testmod(
- pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS)
- globs['sc'].stop()
- if failure_count:
- exit(-1)
-
-
-if __name__ == "__main__":
- _test()
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
new file mode 100644
index 0000000000000..0a5ba00393aab
--- /dev/null
+++ b/python/pyspark/sql/__init__.py
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+public classes of Spark SQL:
+
+ - L{SQLContext}
+ Main entry point for SQL functionality.
+ - L{DataFrame}
+ A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
+ addition to normal RDD operations, DataFrames also support SQL.
+ - L{GroupedData}
+ - L{Column}
+ Column is a DataFrame with a single column.
+ - L{Row}
+ A Row of data returned by a Spark SQL query.
+ - L{HiveContext}
+ Main entry point for accessing data stored in Apache Hive.
+"""
+
+from pyspark.sql.context import SQLContext, HiveContext
+from pyspark.sql.types import Row
+from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD
+
+__all__ = [
+ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
+ 'Dsl',
+]
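For orientation, a minimal end-to-end sketch tying these classes together (it assumes a running SparkContext `sc`):

    >>> from pyspark.sql import SQLContext, Row
    >>> sqlCtx = SQLContext(sc)
    >>> df = sqlCtx.inferSchema(sc.parallelize([Row(name='Alice', age=2)]))
    >>> df.registerTempTable("people")
    >>> sqlCtx.sql("SELECT name FROM people").collect()
    [Row(name=u'Alice')]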
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
new file mode 100644
index 0000000000000..49f016a9cf2e9
--- /dev/null
+++ b/python/pyspark/sql/context.py
@@ -0,0 +1,642 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+import json
+from array import array
+from itertools import imap
+
+from py4j.protocol import Py4JError
+
+from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.sql.types import StringType, StructType, _verify_type, \
+ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
+from pyspark.sql.dataframe import DataFrame
+
+__all__ = ["SQLContext", "HiveContext"]
+
+
+class SQLContext(object):
+
+ """Main entry point for Spark SQL functionality.
+
+ A SQLContext can be used to create L{DataFrame}, register L{DataFrame} as
+ tables, execute SQL over tables, cache tables, and read parquet files.
+ """
+
+ def __init__(self, sparkContext, sqlContext=None):
+ """Create a new SQLContext.
+
+ :param sparkContext: The SparkContext to wrap.
+ :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
+ SQLContext in the JVM, instead we make all calls to this object.
+
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ TypeError:...
+
+ >>> bad_rdd = sc.parallelize([1,2,3])
+ >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+
+ >>> from datetime import datetime
+ >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
+ ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
+ ... time=datetime(2014, 8, 1, 14, 1, 5))])
+ >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df.registerTempTable("allTypes")
+ >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
+ ... 'from allTypes where b and i > 0').collect()
+ [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
+ >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
+ ... x.row.a, x.list)).collect()
+ [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
+ """
+ self._sc = sparkContext
+ self._jsc = self._sc._jsc
+ self._jvm = self._sc._jvm
+ self._scala_SQLContext = sqlContext
+
+ @property
+ def _ssql_ctx(self):
+ """Accessor for the JVM Spark SQL context.
+
+ Subclasses can override this property to provide their own
+ JVM Contexts.
+ """
+ if self._scala_SQLContext is None:
+ self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
+ return self._scala_SQLContext
+
+ def registerFunction(self, name, f, returnType=StringType()):
+ """Registers a lambda function as a UDF so it can be used in SQL statements.
+
+ In addition to a name and the function itself, the return type can be optionally specified.
+ When the return type is not given, it defaults to a string and conversion will automatically
+ be done. For any other return type, the produced object must match the specified type.
+
+ >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
+ >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+ [Row(c0=u'4')]
+ >>> from pyspark.sql.types import IntegerType
+ >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+ [Row(c0=4)]
+ """
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
+ self._ssql_ctx.udf().registerPython(name,
+ bytearray(pickled_cmd),
+ env,
+ includes,
+ self._sc.pythonExec,
+ bvars,
+ self._sc._javaAccumulator,
+ returnType.json())
+
+ def inferSchema(self, rdd, samplingRatio=None):
+ """Infer and apply a schema to an RDD of L{Row}.
+
+ When samplingRatio is specified, the schema is inferred by looking
+ at the types of each row in the sampled dataset. Otherwise, the
+ first 100 rows of the RDD are inspected. Nested collections are
+ supported, which can include array, dict, list, Row, tuple,
+ namedtuple, or object.
+
+ Each row could be a L{pyspark.sql.Row} object, a namedtuple, or a plain object.
+ Using top level dicts is deprecated, as dict is used to represent Maps.
+
+ If a single column has multiple distinct inferred types, it may cause
+ runtime exceptions.
+
+ >>> rdd = sc.parallelize(
+ ... [Row(field1=1, field2="row1"),
+ ... Row(field1=2, field2="row2"),
+ ... Row(field1=3, field2="row3")])
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
+ Row(field1=1, field2=u'row1')
+
+ >>> NestedRow = Row("f1", "f2")
+ >>> nestedRdd1 = sc.parallelize([
+ ... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
+ ... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
+ >>> df = sqlCtx.inferSchema(nestedRdd1)
+ >>> df.collect()
+ [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
+
+ >>> nestedRdd2 = sc.parallelize([
+ ... NestedRow([[1, 2], [2, 3]], [1, 2]),
+ ... NestedRow([[2, 3], [3, 4]], [2, 3])])
+ >>> df = sqlCtx.inferSchema(nestedRdd2)
+ >>> df.collect()
+ [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
+
+ >>> from collections import namedtuple
+ >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
+ >>> rdd = sc.parallelize(
+ ... [CustomRow(field1=1, field2="row1"),
+ ... CustomRow(field1=2, field2="row2"),
+ ... CustomRow(field1=3, field2="row3")])
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
+ Row(field1=1, field2=u'row1')
+ """
+
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
+
+ first = rdd.first()
+ if not first:
+ raise ValueError("The first row in RDD is empty, "
+ "can not infer schema")
+ if type(first) is dict:
+ warnings.warn("Using RDD of dict to inferSchema is deprecated,"
+ " please use pyspark.sql.Row instead")
+
+ if samplingRatio is None:
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for row in rdd.take(100)[1:]:
+ schema = _merge_type(schema, _infer_schema(row))
+ if not _has_nulltype(schema):
+ break
+ else:
+ warnings.warn("Some types cannot be determined by the "
+ "first 100 rows, please try again with sampling")
+ else:
+ if samplingRatio < 0.99:
+ rdd = rdd.sample(False, float(samplingRatio))
+ schema = rdd.map(_infer_schema).reduce(_merge_type)
+
+ converter = _create_converter(schema)
+ rdd = rdd.map(converter)
+ return self.applySchema(rdd, schema)
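To make the sampling note above concrete, a hedged sketch with hypothetical data: without `samplingRatio` only the first 100 rows are inspected, so a conflicting type further down may only surface at runtime, while a ratio forces sampling over the whole RDD:

    >>> mixed = sc.parallelize([Row(x=1)] * 100 + [Row(x='not an int')])
    >>> df = sqlCtx.inferSchema(mixed)       # doctest: +SKIP
    >>> df = sqlCtx.inferSchema(mixed, 0.5)  # doctest: +SKIP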
+
+ def applySchema(self, rdd, schema):
+ """
+ Applies the given schema to the given RDD of L{tuple} or L{list}.
+
+ These tuples or lists can contain complex nested structures like
+ lists, maps or nested rows.
+
+ The schema should be a StructType.
+
+ It is important that the schema matches the types of the objects
+ in each row or exceptions could be thrown at runtime.
+
+ >>> from pyspark.sql.types import *
+ >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
+ >>> schema = StructType([StructField("field1", IntegerType(), False),
+ ... StructField("field2", StringType(), False)])
+ >>> df = sqlCtx.applySchema(rdd2, schema)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT * from table1")
+ >>> df2.collect()
+ [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
+
+ >>> from datetime import date, datetime
+ >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+ ... date(2010, 1, 1),
+ ... datetime(2010, 1, 1, 1, 1, 1),
+ ... {"a": 1}, (2,), [1, 2, 3], None)])
+ >>> schema = StructType([
+ ... StructField("byte1", ByteType(), False),
+ ... StructField("byte2", ByteType(), False),
+ ... StructField("short1", ShortType(), False),
+ ... StructField("short2", ShortType(), False),
+ ... StructField("int", IntegerType(), False),
+ ... StructField("float", FloatType(), False),
+ ... StructField("date", DateType(), False),
+ ... StructField("time", TimestampType(), False),
+ ... StructField("map",
+ ... MapType(StringType(), IntegerType(), False), False),
+ ... StructField("struct",
+ ... StructType([StructField("b", ShortType(), False)]), False),
+ ... StructField("list", ArrayType(ByteType(), False), False),
+ ... StructField("null", DoubleType(), True)])
+ >>> df = sqlCtx.applySchema(rdd, schema)
+ >>> results = df.map(
+ ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
+ ... x.time, x.map["a"], x.struct.b, x.list, x.null))
+ >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
+ (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
+ datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+
+ >>> df.registerTempTable("table2")
+ >>> sqlCtx.sql(
+ ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+ ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
+ ... "float + 1.5 as float FROM table2").collect()
+ [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
+
+ >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
+ >>> rdd = sc.parallelize([(127, -32768, 1.0,
+ ... datetime(2010, 1, 1, 1, 1, 1),
+ ... {"a": 1}, (2,), [1, 2, 3])])
+ >>> abstract = "byte short float time map{} struct(b) list[]"
+ >>> schema = _parse_schema_abstract(abstract)
+ >>> typedSchema = _infer_schema_type(rdd.first(), schema)
+ >>> df = sqlCtx.applySchema(rdd, typedSchema)
+ >>> df.collect()
+ [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
+ """
+
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
+
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+
+ # take the first few rows to verify schema
+ rows = rdd.take(10)
+ # Row() cannot be deserialized by Pyrolite
+ if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+ rdd = rdd.map(tuple)
+ rows = rdd.take(10)
+
+ for row in rows:
+ _verify_type(row, schema)
+
+ # convert python objects to sql data
+ converter = _python_to_sql_converter(schema)
+ rdd = rdd.map(converter)
+
+ jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+ df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ return DataFrame(df, self)
+
+ def registerRDDAsTable(self, rdd, tableName):
+ """Registers the given RDD as a temporary table in the catalog.
+
+ Temporary tables exist only during the lifetime of this instance of
+ SQLContext.
+
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ """
+ if (rdd.__class__ is DataFrame):
+ df = rdd._jdf
+ self._ssql_ctx.registerRDDAsTable(df, tableName)
+ else:
+ raise ValueError("Can only register DataFrame as table")
+
+ def parquetFile(self, *paths):
+ """Loads a Parquet file, returning the result as a L{DataFrame}.
+
+ >>> import tempfile, shutil
+ >>> parquetFile = tempfile.mkdtemp()
+ >>> shutil.rmtree(parquetFile)
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df.collect()) == sorted(df2.collect())
+ True
+ """
+ gateway = self._sc._gateway
+ jpath = paths[0]
+ jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1)
+ for i in range(1, len(paths)):
+ jpaths[i - 1] = paths[i]
+ jdf = self._ssql_ctx.parquetFile(jpath, jpaths)
+ return DataFrame(jdf, self)
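Because the method accepts a variable number of paths, a hedged sketch reading two Parquet directories at once (hypothetical paths):

    >>> df = sqlCtx.parquetFile("/data/part1", "/data/part2")  # doctest: +SKIP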
+
+ def jsonFile(self, path, schema=None, samplingRatio=1.0):
+ """
+ Loads a text file storing one JSON object per line as a
+ L{DataFrame}.
+
+ If the schema is provided, applies the given schema to this
+ JSON dataset.
+
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
+
+ >>> import tempfile, shutil
+ >>> jsonFile = tempfile.mkdtemp()
+ >>> shutil.rmtree(jsonFile)
+ >>> ofn = open(jsonFile, 'w')
+ >>> for json in jsonStrings:
+ ... print>>ofn, json
+ >>> ofn.close()
+ >>> df1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+ ... "field6 as f4 from table1")
+ >>> for r in df2.collect():
+ ... print r
+ Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+ Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
+ Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+ >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+ ... "field6 as f4 from table2")
+ >>> for r in df4.collect():
+ ... print r
+ Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+ Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
+ Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+ >>> from pyspark.sql.types import *
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5",
+ ... ArrayType(IntegerType(), False), True)]), False)])
+ >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, "
+ ... "field3.field5[0] as f3 from table3")
+ >>> df6.collect()
+ [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
+ """
+ if schema is None:
+ df = self._ssql_ctx.jsonFile(path, samplingRatio)
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.jsonFile(path, scala_datatype)
+ return DataFrame(df, self)
+
+ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
+ """Loads an RDD storing one JSON object per string as a L{DataFrame}.
+
+ If the schema is provided, applies the given schema to this
+ JSON dataset.
+
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
+
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+ ... "field6 as f4 from table1")
+ >>> for r in df2.collect():
+ ... print r
+ Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+ Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
+ Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+ >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+ ... "field6 as f4 from table2")
+ >>> for r in df4.collect():
+ ... print r
+ Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+ Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
+ Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+ >>> from pyspark.sql.types import *
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5",
+ ... ArrayType(IntegerType(), False), True)]), False)])
+ >>> df5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, "
+ ... "field3.field5[0] as f3 from table3")
+ >>> df6.collect()
+ [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
+
+ >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
+ ... '{"key0": {"key1": "value1"}}'])).collect()
+ [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+ >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
+ ... '{"key0": {"key1": "value1"}}'])).collect()
+ [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+ """
+
+ def func(iterator):
+ for x in iterator:
+ if not isinstance(x, basestring):
+ x = unicode(x)
+ if isinstance(x, unicode):
+ x = x.encode("utf-8")
+ yield x
+ keyed = rdd.mapPartitions(func)
+ keyed._bypass_serializer = True
+ jrdd = keyed._jrdd.map(self._jvm.BytesToString())
+ if schema is None:
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+ return DataFrame(df, self)
+
+ def sql(self, sqlQuery):
+ """Return a L{DataFrame} representing the result of the given query.
+
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> df2.collect()
+ [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
+ """
+ return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
+
+ def table(self, tableName):
+ """Returns the specified table as a L{DataFrame}.
+
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.table("table1")
+ >>> sorted(df.collect()) == sorted(df2.collect())
+ True
+ """
+ return DataFrame(self._ssql_ctx.table(tableName), self)
+
+ def cacheTable(self, tableName):
+ """Caches the specified table in-memory."""
+ self._ssql_ctx.cacheTable(tableName)
+
+ def uncacheTable(self, tableName):
+ """Removes the specified table from the in-memory cache."""
+ self._ssql_ctx.uncacheTable(tableName)
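A small sketch of the cache lifecycle, assuming "table1" has been registered as in the doctests above:

    >>> sqlCtx.cacheTable("table1")    # doctest: +SKIP
    >>> sqlCtx.uncacheTable("table1")  # doctest: +SKIP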
+
+
+class HiveContext(SQLContext):
+
+ """A variant of Spark SQL that integrates with data stored in Hive.
+
+ Configuration for Hive is read from hive-site.xml on the classpath.
+ It supports running both SQL and HiveQL commands.
+ """
+
+ def __init__(self, sparkContext, hiveContext=None):
+ """Create a new HiveContext.
+
+ :param sparkContext: The SparkContext to wrap.
+ :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
+ HiveContext in the JVM, instead we make all calls to this object.
+ """
+ SQLContext.__init__(self, sparkContext)
+
+ if hiveContext:
+ self._scala_HiveContext = hiveContext
+
+ @property
+ def _ssql_ctx(self):
+ try:
+ if not hasattr(self, '_scala_HiveContext'):
+ self._scala_HiveContext = self._get_hive_ctx()
+ return self._scala_HiveContext
+ except Py4JError as e:
+ raise Exception("You must build Spark with Hive. "
+ "Export 'SPARK_HIVE=true' and run "
+ "build/sbt assembly", e)
+
+ def _get_hive_ctx(self):
+ return self._jvm.HiveContext(self._jsc.sc())
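A minimal usage sketch; it assumes a Hive-enabled build and a hive-site.xml on the classpath, as described above:

    >>> from pyspark.sql import HiveContext
    >>> hiveCtx = HiveContext(sc)             # doctest: +SKIP
    >>> hiveCtx.sql("SHOW TABLES").collect()  # doctest: +SKIP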
+
+
+def _create_row(fields, values):
+ row = Row(*values)
+ row.__FIELDS__ = fields
+ return row
+
+
+class Row(tuple):
+
+ """
+ A row in L{DataFrame}. The fields in it can be accessed like attributes.
+
+ Row can be used to create a row object by using named arguments;
+ the fields will be sorted by name.
+
+ >>> row = Row(name="Alice", age=11)
+ >>> row
+ Row(age=11, name='Alice')
+ >>> row.name, row.age
+ ('Alice', 11)
+
+ Row can also be used to create another Row-like class, which can
+ then be used to create Row objects, such as
+
+ >>> Person = Row("name", "age")
+ >>> Person
+ <Row(name, age)>
+ >>> Person("Alice", 11)
+ Row(name='Alice', age=11)
+ """
+
+ def __new__(self, *args, **kwargs):
+ if args and kwargs:
+ raise ValueError("Can not use both args "
+ "and kwargs to create Row")
+ if args:
+ # create row class or objects
+ return tuple.__new__(self, args)
+
+ elif kwargs:
+ # create row objects
+ names = sorted(kwargs.keys())
+ values = tuple(kwargs[n] for n in names)
+ row = tuple.__new__(self, values)
+ row.__FIELDS__ = names
+ return row
+
+ else:
+ raise ValueError("No args or kwargs")
+
+ def asDict(self):
+ """
+ Return as a dict
+ """
+ if not hasattr(self, "__FIELDS__"):
+ raise TypeError("Cannot convert a Row class into dict")
+ return dict(zip(self.__FIELDS__, self))
+
+ # let the object act like a class
+ def __call__(self, *args):
+ """create new Row object"""
+ return _create_row(self, args)
+
+ def __getattr__(self, item):
+ if item.startswith("__"):
+ raise AttributeError(item)
+ try:
+ # it will be slow when it has many fields,
+ # but this will not be used in normal cases
+ idx = self.__FIELDS__.index(item)
+ return self[idx]
+ except IndexError:
+ raise AttributeError(item)
+
+ def __reduce__(self):
+ if hasattr(self, "__FIELDS__"):
+ return (_create_row, (self.__FIELDS__, tuple(self)))
+ else:
+ return tuple.__reduce__(self)
+
+ def __repr__(self):
+ if hasattr(self, "__FIELDS__"):
+ return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+ for k, v in zip(self.__FIELDS__, self))
+ else:
+ return "<Row(%s)>" % ", ".join(self)
+
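+# A small usage sketch for the Row helpers above (kept as comments, so the
+# doctest runner does not pick it up); the output shown assumes the Python 2
+# interpreter used elsewhere in this module.
+#
+# >>> Person = Row("name", "age")          # positional args build a Row "class"
+# >>> alice = Person("Alice", 11)          # __call__ delegates to _create_row
+# >>> sorted(alice.asDict().items())
+# [('age', 11), ('name', 'Alice')]
+# >>> import pickle
+# >>> pickle.loads(pickle.dumps(alice))    # __reduce__ preserves __FIELDS__
+# Row(name='Alice', age=11)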
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row, SQLContext
+ import pyspark.sql.context
+ globs = pyspark.sql.context.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+ globs['rdd'] = sc.parallelize(
+ [Row(field1=1, field2="row1"),
+ Row(field1=2, field2="row2"),
+ Row(field1=3, field2="row3")]
+ )
+ jsonStrings = [
+ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
+ '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
+ '"field6":[{"field7": "row2"}]}',
+ '{"field1" : null, "field2": "row3", '
+ '"field3":{"field4":33, "field5": []}}'
+ ]
+ globs['jsonStrings'] = jsonStrings
+ globs['json'] = sc.parallelize(jsonStrings)
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
new file mode 100644
index 0000000000000..cda704eea75f5
--- /dev/null
+++ b/python/pyspark/sql/dataframe.py
@@ -0,0 +1,974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import itertools
+import warnings
+import random
+import os
+from tempfile import NamedTemporaryFile
+from itertools import imap
+
+from py4j.java_collections import ListConverter, MapConverter
+
+from pyspark.context import SparkContext
+from pyspark.rdd import RDD, _prepare_for_python_RDD
+from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
+ UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.types import *
+from pyspark.sql.types import _create_cls, _parse_datatype_json_string
+
+
+__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+
+
+class DataFrame(object):
+
+ """A collection of rows that have the same columns.
+
+ A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+ and can be created using various functions in :class:`SQLContext`::
+
+ people = sqlContext.parquetFile("...")
+
+ Once created, it can be manipulated using the various domain-specific-language
+ (DSL) functions defined in :class:`DataFrame` and :class:`Column`.
+
+ To select a column from the data frame, use the apply method::
+
+ ageCol = people.age
+
+ Note that the :class:`Column` type can also be manipulated
+ through its various functions::
+
+ # The following creates a new column that increases everybody's age by 10.
+ people.age + 10
+
+
+ A more concrete example::
+
+ # To create DataFrame using SQLContext
+ people = sqlContext.parquetFile("...")
+ department = sqlContext.parquetFile("...")
+
+ people.filter(people.age > 30).join(department, people.deptId == department.id) \
+ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
+ """
+
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
+ self.sql_ctx = sql_ctx
+ self._sc = sql_ctx and sql_ctx._sc
+ self.is_cached = False
+
+ @property
+ def rdd(self):
+ """
+ Return the content of the :class:`DataFrame` as an :class:`RDD`
+ of :class:`Row` s.
+ """
+ if not hasattr(self, '_lazy_rdd'):
+ jrdd = self._jdf.javaToPython()
+ rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+ schema = self.schema()
+
+ def applySchema(it):
+ cls = _create_cls(schema)
+ return itertools.imap(cls, it)
+
+ self._lazy_rdd = rdd.mapPartitions(applySchema)
+
+ return self._lazy_rdd
+
+ def toJSON(self, use_unicode=False):
+ """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
+
+ >>> df.toJSON().first()
+ '{"age":2,"name":"Alice"}'
+ """
+ rdd = self._jdf.toJSON()
+ return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
+
+ def saveAsParquetFile(self, path):
+ """Save the contents as a Parquet file, preserving the schema.
+
+ Files that are written out using this method can be read back in as
+ a DataFrame using the L{SQLContext.parquetFile} method.
+
+ >>> import tempfile, shutil
+ >>> parquetFile = tempfile.mkdtemp()
+ >>> shutil.rmtree(parquetFile)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df2.collect()) == sorted(df.collect())
+ True
+ """
+ self._jdf.saveAsParquetFile(path)
+
+ def registerTempTable(self, name):
+ """Registers this RDD as a temporary table using the given name.
+
+ The lifetime of this temporary table is tied to the L{SQLContext}
+ that was used to create this DataFrame.
+
+ >>> df.registerTempTable("people")
+ >>> df2 = sqlCtx.sql("select * from people")
+ >>> sorted(df.collect()) == sorted(df2.collect())
+ True
+ """
+ self._jdf.registerTempTable(name)
+
+ def registerAsTable(self, name):
+ """DEPRECATED: use registerTempTable() instead"""
+ warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
+ self.registerTempTable(name)
+
+ def insertInto(self, tableName, overwrite=False):
+ """Inserts the contents of this DataFrame into the specified table.
+
+ Optionally overwrites any existing data.
+ """
+ self._jdf.insertInto(tableName, overwrite)
+
+ def saveAsTable(self, tableName):
+ """Creates a new table with the contents of this DataFrame."""
+ self._jdf.saveAsTable(tableName)
+
+ def schema(self):
+ """Returns the schema of this DataFrame (represented by
+ a L{StructType}).
+
+ >>> df.schema()
+ StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
+ """
+ return _parse_datatype_json_string(self._jdf.schema().json())
+
+ def printSchema(self):
+ """Prints out the schema in the tree format.
+
+ >>> df.printSchema()
+ root
+ |-- age: integer (nullable = true)
+ |-- name: string (nullable = true)
+
+ """
+ print (self._jdf.schema().treeString())
+
+ def count(self):
+ """Return the number of elements in this RDD.
+
+ Unlike the base RDD implementation of count, this implementation
+ leverages the query optimizer to compute the count on the DataFrame,
+ which supports features such as filter pushdown.
+
+ >>> df.count()
+ 2L
+ """
+ return self._jdf.count()
+
+ def collect(self):
+ """Return a list that contains all of the rows.
+
+ Each object in the list is a Row; the fields can be accessed as
+ attributes.
+
+ >>> df.collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ """
+ with SCCallSiteSync(self._sc) as css:
+ bytesInJava = self._jdf.javaToPython().collect().iterator()
+ tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+ tempFile.close()
+ self._sc._writeToFile(bytesInJava, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
+ os.unlink(tempFile.name)
+ cls = _create_cls(self.schema())
+ return [cls(r) for r in rs]
+
+ def limit(self, num):
+ """Limit the result count to the number specified.
+
+ >>> df.limit(1).collect()
+ [Row(age=2, name=u'Alice')]
+ >>> df.limit(0).collect()
+ []
+ """
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sql_ctx)
+
+ def take(self, num):
+ """Take the first num rows of the RDD.
+
+ Each object in the list is a Row; the fields can be accessed as
+ attributes.
+
+ >>> df.take(2)
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ """
+ return self.limit(num).collect()
+
+ def map(self, f):
+ """ Return a new RDD by applying a function to each Row; this is a
+ shorthand for df.rdd.map().
+
+ >>> df.map(lambda p: p.name).collect()
+ [u'Alice', u'Bob']
+ """
+ return self.rdd.map(f)
+
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(iterator): yield 1
+ >>> rdd.mapPartitions(f).sum()
+ 4
+ """
+ return self.rdd.mapPartitions(f, preservesPartitioning)
+
+ def cache(self):
+ """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+ """
+ self.is_cached = True
+ self._jdf.cache()
+ return self
+
+ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+ """ Set the storage level to persist its values across operations
+ after the first time it is computed. This can only be used to assign
+ a new storage level if the RDD does not have a storage level set yet.
+ If no storage level is specified, it defaults to C{MEMORY_ONLY_SER}.
+ """
+ self.is_cached = True
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdf.persist(javaStorageLevel)
+ return self
+
+ def unpersist(self, blocking=True):
+ """ Mark it as non-persistent, and remove all blocks for it from
+ memory and disk.
+ """
+ self.is_cached = False
+ self._jdf.unpersist(blocking)
+ return self
+
+ # def coalesce(self, numPartitions, shuffle=False):
+ # rdd = self._jdf.coalesce(numPartitions, shuffle, None)
+ # return DataFrame(rdd, self.sql_ctx)
+
+ def repartition(self, numPartitions):
+ """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+ partitions.
+ """
+ rdd = self._jdf.repartition(numPartitions, None)
+ return DataFrame(rdd, self.sql_ctx)
+
+ def sample(self, withReplacement, fraction, seed=None):
+ """
+ Return a sampled subset of this DataFrame.
+
+ >>> df.sample(False, 0.5, 97).count()
+ 1L
+ """
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ seed = seed if seed is not None else random.randint(0, sys.maxint)
+ rdd = self._jdf.sample(withReplacement, fraction, long(seed))
+ return DataFrame(rdd, self.sql_ctx)
+
+ # def takeSample(self, withReplacement, num, seed=None):
+ # """Return a fixed-size sampled subset of this DataFrame.
+ #
+ # >>> df = sqlCtx.inferSchema(rdd)
+ # >>> df.takeSample(False, 2, 97)
+ # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ # """
+ # seed = seed if seed is not None else random.randint(0, sys.maxint)
+ # with SCCallSiteSync(self.context) as css:
+ # bytesInJava = self._jdf \
+ # .takeSampleToPython(withReplacement, num, long(seed)) \
+ # .iterator()
+ # cls = _create_cls(self.schema())
+ # return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+ @property
+ def dtypes(self):
+ """Return all column names and their data types as a list.
+
+ >>> df.dtypes
+ [('age', 'integer'), ('name', 'string')]
+ """
+ return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
+
+ @property
+ def columns(self):
+ """ Return all column names as a list.
+
+ >>> df.columns
+ [u'age', u'name']
+ """
+ return [f.name for f in self.schema().fields]
+
+ def join(self, other, joinExprs=None, joinType=None):
+ """
+ Join with another DataFrame, using the given join expression.
+ The example below performs a full outer join between `df` and `df2`.
+
+ :param other: Right side of the join
+ :param joinExprs: Join expression
+ :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
+
+ >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
+ [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
+ """
+
+ if joinExprs is None:
+ jdf = self._jdf.join(other._jdf)
+ else:
+ assert isinstance(joinExprs, Column), "joinExprs should be Column"
+ if joinType is None:
+ jdf = self._jdf.join(other._jdf, joinExprs._jc)
+ else:
+ assert isinstance(joinType, basestring), "joinType should be basestring"
+ jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+ return DataFrame(jdf, self.sql_ctx)
+
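+ # For reference, a hedged sketch of the three accepted call forms
+ # (mirroring the dispatch above; `df` and `df2` are the doctest DataFrames):
+ #
+ # df.join(df2)                                # no condition: a cartesian product
+ # df.join(df2, df.name == df2.name)           # join on a Column expression
+ # df.join(df2, df.name == df2.name, "outer")  # explicit join type
+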
+ def sort(self, *cols):
+ """ Return a new :class:`DataFrame` sorted by the specified column.
+
+ :param cols: The columns or expressions used for sorting
+
+ >>> df.sort(df.age.desc()).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ >>> df.sortBy(df.age.desc()).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ """
+ if not cols:
+ raise ValueError("should sort by at least one column")
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
+ jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ sortBy = sort
+
+ def head(self, n=None):
+ """ Return the first `n` rows or the first row if n is None.
+
+ >>> df.head()
+ Row(age=2, name=u'Alice')
+ >>> df.head(1)
+ [Row(age=2, name=u'Alice')]
+ """
+ if n is None:
+ rs = self.head(1)
+ return rs[0] if rs else None
+ return self.take(n)
+
+ def first(self):
+ """ Return the first row.
+
+ >>> df.first()
+ Row(age=2, name=u'Alice')
+ """
+ return self.head()
+
+ def __getitem__(self, item):
+ """ Return the column by the given name
+
+ >>> df['age'].collect()
+ [Row(age=2), Row(age=5)]
+ >>> df[ ["name", "age"]].collect()
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+ >>> df[ df.age > 3 ].collect()
+ [Row(age=5, name=u'Bob')]
+ """
+ if isinstance(item, basestring):
+ jc = self._jdf.apply(item)
+ return Column(jc, self.sql_ctx)
+ elif isinstance(item, Column):
+ return self.filter(item)
+ elif isinstance(item, list):
+ return self.select(*item)
+ else:
+ raise IndexError("unexpected index: %s" % item)
+
+ def __getattr__(self, name):
+ """ Return the column by the given name
+
+ >>> df.age.collect()
+ [Row(age=2), Row(age=5)]
+ """
+ if name.startswith("__"):
+ raise AttributeError(name)
+ jc = self._jdf.apply(name)
+ return Column(jc, self.sql_ctx)
+
+ def select(self, *cols):
+ """ Select a set of expressions.
+
+ >>> df.select().collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ >>> df.select('*').collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ >>> df.select('name', 'age').collect()
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+ >>> df.select(df.name, (df.age + 10).alias('age')).collect()
+ [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
+ """
+ if not cols:
+ cols = ["*"]
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
+ jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ def selectExpr(self, *expr):
+ """
+ Selects a set of SQL expressions. This is a variant of
+ `select` that accepts SQL expressions.
+
+ >>> df.selectExpr("age * 2", "abs(age)").collect()
+ [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+ """
+ jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
+ jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
+ return DataFrame(jdf, self.sql_ctx)
+
+ def filter(self, condition):
+ """ Filter rows using the given condition, which may be a Column
+ expression or a SQL expression string.
+
+ where() is an alias for filter().
+
+ >>> df.filter(df.age > 3).collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df.where(df.age == 2).collect()
+ [Row(age=2, name=u'Alice')]
+
+ >>> df.filter("age > 3").collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df.where("age = 2").collect()
+ [Row(age=2, name=u'Alice')]
+ """
+ if isinstance(condition, basestring):
+ jdf = self._jdf.filter(condition)
+ elif isinstance(condition, Column):
+ jdf = self._jdf.filter(condition._jc)
+ else:
+ raise TypeError("condition should be string or Column")
+ return DataFrame(jdf, self.sql_ctx)
+
+ where = filter
+
+ def groupBy(self, *cols):
+ """ Group the :class:`DataFrame` using the specified columns,
+ so we can run aggregation on them. See :class:`GroupedData`
+ for all the available aggregate functions.
+
+ >>> df.groupBy().avg().collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df.groupBy('name').agg({'age': 'mean'}).collect()
+ [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+ >>> df.groupBy(df.name).avg().collect()
+ [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+ """
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
+ jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return GroupedData(jdf, self.sql_ctx)
+
+ def agg(self, *exprs):
+ """ Aggregate on the entire :class:`DataFrame` without groups
+ (shorthand for df.groupBy().agg()).
+
+ >>> df.agg({"age": "max"}).collect()
+ [Row(MAX(age#0)=5)]
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0)=2)]
+ """
+ return self.groupBy().agg(*exprs)
+
+ def unionAll(self, other):
+ """ Return a new DataFrame containing union of rows in this
+ frame and another frame.
+
+ This is equivalent to `UNION ALL` in SQL.
+ """
+ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+
+ def intersect(self, other):
+ """ Return a new :class:`DataFrame` containing rows only in
+ both this frame and another frame.
+
+ This is equivalent to `INTERSECT` in SQL.
+ """
+ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+
+ def subtract(self, other):
+ """ Return a new :class:`DataFrame` containing rows in this frame
+ but not in another frame.
+
+ This is equivalent to `EXCEPT` in SQL.
+ """
+ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+
+ def addColumn(self, colName, col):
+ """ Return a new :class:`DataFrame` by adding a column.
+
+ >>> df.addColumn('age2', df.age + 2).collect()
+ [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
+ """
+ return self.select('*', col.alias(colName))
+
+ def to_pandas(self):
+ """
+ Collect all the rows and return a `pandas.DataFrame`.
+
+ >>> df.to_pandas() # doctest: +SKIP
+ age name
+ 0 2 Alice
+ 1 5 Bob
+ """
+ import pandas as pd
+ return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+
+# Having SchemaRDD for backward compatibility (for docs)
+class SchemaRDD(DataFrame):
+ """
+ SchemaRDD is deprecated; please use DataFrame instead.
+ """
+
+
+def dfapi(f):
+ def _api(self):
+ name = f.__name__
+ jdf = getattr(self._jdf, name)()
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
+class GroupedData(object):
+
+ """
+ A set of methods for aggregations on a :class:`DataFrame`,
+ created by DataFrame.groupBy().
+ """
+
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
+ self.sql_ctx = sql_ctx
+
+ def agg(self, *exprs):
+ """ Compute aggregates by specifying a map from column name
+ to aggregate methods.
+
+ The available aggregate methods are `avg`, `max`, `min`,
+ `sum`, `count`.
+
+ :param exprs: list of aggregate columns or a map from column
+ name to aggregate methods.
+
+ >>> gdf = df.groupBy(df.name)
+ >>> gdf.agg({"age": "max"}).collect()
+ [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
+ >>> from pyspark.sql import Dsl
+ >>> gdf.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
+ """
+ assert exprs, "exprs should not be empty"
+ if len(exprs) == 1 and isinstance(exprs[0], dict):
+ jmap = MapConverter().convert(exprs[0],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(jmap)
+ else:
+ # Columns
+ assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+ jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ @dfapi
+ def count(self):
+ """ Count the number of rows for each group.
+
+ >>> df.groupBy(df.age).count().collect()
+ [Row(age=2, count=1), Row(age=5, count=1)]
+ """
+
+ @dfapi
+ def mean(self):
+ """Compute the average value for each numeric column
+ for each group. This is an alias for `avg`."""
+
+ @dfapi
+ def avg(self):
+ """Compute the average value for each numeric column
+ for each group."""
+
+ @dfapi
+ def max(self):
+ """Compute the max value for each numeric column for
+ each group."""
+
+ @dfapi
+ def min(self):
+ """Compute the min value for each numeric column for
+ each group."""
+
+ @dfapi
+ def sum(self):
+ """Compute the sum for each numeric column for each
+ group."""
+
+
+def _create_column_from_literal(literal):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Dsl.lit(literal)
+
+
+def _create_column_from_name(name):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Dsl.col(name)
+
+
+def _to_java_column(col):
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ return jcol
+
+
+def _unary_op(name, doc="unary operator"):
+ """ Create a method for the given unary operator """
+ def _(self):
+ jc = getattr(self._jc, name)()
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
+
+def _dsl_op(name, doc=''):
+ def _(self):
+ jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
+
+def _bin_op(name, doc="binary operator"):
+ """ Create a method for the given binary operator
+ """
+ def _(self, other):
+ jc = other._jc if isinstance(other, Column) else other
+ njc = getattr(self._jc, name)(jc)
+ return Column(njc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
+
+def _reverse_op(name, doc="binary operator"):
+ """ Create a method for binary operator (this object is on right side)
+ """
+ def _(self, other):
+ jother = _create_column_from_literal(other)
+ jc = getattr(jother, name)(self._jc)
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
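+# A hedged illustration of how the factories above are wired into the Column
+# class below: `__add__` is simply _bin_op("plus"), so `df.age + 1` forwards
+# to the JVM column's `plus` method, while `1 - df.age` goes through
+# _reverse_op("minus"), which first turns the literal into a JVM column.
+#
+#   (df.age + 1)._jc    # == df.age._jc.plus(1)
+#   (1 - df.age)._jc    # == sc._jvm.Dsl.lit(1).minus(df.age._jc)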
+
+class Column(DataFrame):
+
+ """
+ A column in a DataFrame.
+
+ `Column` instances can be created by::
+
+ # 1. Select a column out of a DataFrame
+ df.colName
+ df["colName"]
+
+ # 2. Create from an expression
+ df.colName + 1
+ 1 / df.colName
+ """
+
+ def __init__(self, jc, sql_ctx=None):
+ self._jc = jc
+ super(Column, self).__init__(jc, sql_ctx)
+
+ # arithmetic operators
+ __neg__ = _dsl_op("negate")
+ __add__ = _bin_op("plus")
+ __sub__ = _bin_op("minus")
+ __mul__ = _bin_op("multiply")
+ __div__ = _bin_op("divide")
+ __mod__ = _bin_op("mod")
+ __radd__ = _bin_op("plus")
+ __rsub__ = _reverse_op("minus")
+ __rmul__ = _bin_op("multiply")
+ __rdiv__ = _reverse_op("divide")
+ __rmod__ = _reverse_op("mod")
+
+ # logistic operators
+ __eq__ = _bin_op("equalTo")
+ __ne__ = _bin_op("notEqual")
+ __lt__ = _bin_op("lt")
+ __le__ = _bin_op("leq")
+ __ge__ = _bin_op("geq")
+ __gt__ = _bin_op("gt")
+
+ # `and`, `or`, `not` cannot be overloaded in Python,
+ # so use bitwise operators as boolean operators
+ __and__ = _bin_op('and')
+ __or__ = _bin_op('or')
+ __invert__ = _dsl_op('not')
+ __rand__ = _bin_op("and")
+ __ror__ = _bin_op("or")
+
+ # container operators
+ __contains__ = _bin_op("contains")
+ __getitem__ = _bin_op("getItem")
+ getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
+
+ # string methods
+ rlike = _bin_op("rlike")
+ like = _bin_op("like")
+ startswith = _bin_op("startsWith")
+ endswith = _bin_op("endsWith")
+
+ def substr(self, startPos, length):
+ """
+ Return a Column which is a substring of the column
+
+ :param startPos: start position (int or Column)
+ :param length: length of the substring (int or Column)
+
+ >>> df.name.substr(1, 3).collect()
+ [Row(col=u'Ali'), Row(col=u'Bob')]
+ """
+ if type(startPos) != type(length):
+ raise TypeError("Can not mix the type")
+ if isinstance(startPos, (int, long)):
+ jc = self._jc.substr(startPos, length)
+ elif isinstance(startPos, Column):
+ jc = self._jc.substr(startPos._jc, length._jc)
+ else:
+ raise TypeError("Unexpected type: %s" % type(startPos))
+ return Column(jc, self.sql_ctx)
+
+ __getslice__ = substr
+
+ # order
+ asc = _unary_op("asc")
+ desc = _unary_op("desc")
+
+ isNull = _unary_op("isNull", "True if the current expression is null.")
+ isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
+
+ def alias(self, alias):
+ """Return an alias for this column
+
+ >>> df.age.alias("age2").collect()
+ [Row(age2=2), Row(age2=5)]
+ """
+ return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+
+ def cast(self, dataType):
+ """ Convert the column into type `dataType`
+
+ >>> df.select(df.age.cast("string").alias('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ """
+ if self.sql_ctx is None:
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ else:
+ ssql_ctx = self.sql_ctx._ssql_ctx
+ if isinstance(dataType, basestring):
+ jc = self._jc.cast(dataType)
+ elif isinstance(dataType, DataType):
+ jdt = ssql_ctx.parseDataType(dataType.json())
+ jc = self._jc.cast(jdt)
+ return Column(jc, self.sql_ctx)
+
+ def to_pandas(self):
+ """
+ Return a pandas.Series from the column
+
+ >>> df.age.to_pandas() # doctest: +SKIP
+ 0 2
+ 1 5
+ dtype: int64
+ """
+ import pandas as pd
+ data = [c for c, in self.collect()]
+ return pd.Series(data)
+
+
+def _aggregate_func(name, doc=""):
+ """ Create a function for an aggregator by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
+ return Column(jc)
+ _.__name__ = name
+ _.__doc__ = doc
+ return staticmethod(_)
+
+
+class UserDefinedFunction(object):
+ def __init__(self, func, returnType):
+ self.func = func
+ self.returnType = returnType
+ self._broadcast = None
+ self._judf = self._create_judf()
+
+ def _create_judf(self):
+ f = self.func # capture in a local so the closure does not reference `self`
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ sc = SparkContext._active_spark_context
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ jdt = ssql_ctx.parseDataType(self.returnType.json())
+ judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+ includes, sc.pythonExec, broadcast_vars,
+ sc._javaAccumulator, jdt)
+ return judf
+
+ def __del__(self):
+ if self._broadcast is not None:
+ self._broadcast.unpersist()
+ self._broadcast = None
+
+ def __call__(self, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
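+# A rough sketch of the lifecycle above (illustrative only): `Dsl.udf` below
+# wraps a plain Python callable in a UserDefinedFunction, and calling the
+# result builds a Column expression that applies the function row by row.
+#
+#   slen = UserDefinedFunction(lambda s: len(s), IntegerType())
+#   df.select(slen(df.name).alias("n_chars"))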
+
+class Dsl(object):
+ """
+ A collection of builtin aggregators
+ """
+ DSLS = {
+ 'lit': 'Creates a :class:`Column` of literal value.',
+ 'col': 'Returns a :class:`Column` based on the given column name.',
+ 'column': 'Returns a :class:`Column` based on the given column name.',
+ 'upper': 'Converts a string expression to upper case.',
+ 'lower': 'Converts a string expression to lower case.',
+ 'sqrt': 'Computes the square root of the specified float value.',
+ 'abs': 'Computes the absolute value.',
+
+ 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+ 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+ 'first': 'Aggregate function: returns the first value in a group.',
+ 'last': 'Aggregate function: returns the last value in a group.',
+ 'count': 'Aggregate function: returns the number of items in a group.',
+ 'sum': 'Aggregate function: returns the sum of all values in the expression.',
+ 'avg': 'Aggregate function: returns the average of the values in a group.',
+ 'mean': 'Aggregate function: returns the average of the values in a group.',
+ 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+ }
+
+ for _name, _doc in DSLS.items():
+ locals()[_name] = _aggregate_func(_name, _doc)
+ del _name, _doc
+
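+ # A sketch of what the loop above generates (illustrative only): each DSLS
+ # entry becomes a staticmethod that forwards to the JVM Dsl object, e.g.
+ #
+ # Dsl.min(df.age)    # == Column(sc._jvm.Dsl.min(df.age._jc))
+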
+ @staticmethod
+ def countDistinct(col, *cols):
+ """ Return a new Column for the distinct count of (col, *cols).
+
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
+ [Row(c=2)]
+
+ >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+ sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+ @staticmethod
+ def approxCountDistinct(col, rsd=None):
+ """ Return a new Column for the approximate distinct count of `col`.
+
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
+
+ @staticmethod
+ def udf(f, returnType=StringType()):
+ """Create a user defined function (UDF)
+
+ >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+ >>> df.select(slen(df.name).alias('slen')).collect()
+ [Row(slen=5), Row(slen=3)]
+ """
+ return UserDefinedFunction(f, returnType)
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row, SQLContext
+ import pyspark.sql.dataframe
+ globs = pyspark.sql.dataframe.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+ rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
+ rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
+ globs['df'] = sqlCtx.inferSchema(rdd2)
+ globs['df2'] = sqlCtx.inferSchema(rdd3)
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql/tests.py
similarity index 96%
rename from python/pyspark/sql_tests.py
rename to python/pyspark/sql/tests.py
index d314f46e8d2d5..d25c6365ed067 100644
--- a/python/pyspark/sql_tests.py
+++ b/python/pyspark/sql/tests.py
@@ -34,8 +34,10 @@
else:
import unittest
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType
+
+from pyspark.sql import SQLContext, Column
+from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
+ UserDefinedType, DoubleType, LongType
from pyspark.tests import ReusedPySparkTestCase
@@ -220,7 +222,7 @@ def test_convert_row_to_dict(self):
self.assertEqual(1.0, row.asDict()['d']['key'].c)
def test_infer_schema_with_udt(self):
- from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
df = self.sqlCtx.inferSchema(rdd)
@@ -232,7 +234,7 @@ def test_infer_schema_with_udt(self):
self.assertEqual(point, ExamplePoint(1.0, 2.0))
def test_apply_schema_with_udt(self):
- from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = (1.0, ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
@@ -242,7 +244,7 @@ def test_apply_schema_with_udt(self):
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
- from pyspark.sql_tests import ExamplePoint
+ from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
df0 = self.sqlCtx.inferSchema(rdd)
@@ -253,7 +255,6 @@ def test_parquet_with_udt(self):
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_column_operators(self):
- from pyspark.sql import Column, LongType
ci = self.df.key
cs = self.df.value
c = ci == cs
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
new file mode 100644
index 0000000000000..41afefe48ee5e
--- /dev/null
+++ b/python/pyspark/sql/types.py
@@ -0,0 +1,1279 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import decimal
+import datetime
+import keyword
+import warnings
+import json
+import re
+from array import array
+from operator import itemgetter
+
+
+__all__ = [
+ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
+ "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
+ "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ]
+
+
+class DataType(object):
+
+ """Spark SQL DataType"""
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+ def __hash__(self):
+ return hash(str(self))
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.__dict__ == other.__dict__)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ @classmethod
+ def typeName(cls):
+ return cls.__name__[:-4].lower()
+
+ def jsonValue(self):
+ return self.typeName()
+
+ def json(self):
+ return json.dumps(self.jsonValue(),
+ separators=(',', ':'),
+ sort_keys=True)
+
+
+class PrimitiveTypeSingleton(type):
+
+ """Metaclass for PrimitiveType"""
+
+ _instances = {}
+
+ def __call__(cls):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
+ return cls._instances[cls]
+
+
+class PrimitiveType(DataType):
+
+ """Spark SQL PrimitiveType"""
+
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __eq__(self, other):
+ # because they should be the same object
+ return self is other
+
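+# An illustrative consequence of the metaclass above: each parameterless
+# primitive type is constructed at most once, so instances can be compared
+# with `is`.
+#
+# >>> StringType() is StringType()
+# True
+# >>> BooleanType() is StringType()
+# False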
+
+class NullType(PrimitiveType):
+
+ """Spark SQL NullType
+
+ The data type representing None, used for types that have not
+ been inferred.
+ """
+
+
+class StringType(PrimitiveType):
+
+ """Spark SQL StringType
+
+ The data type representing string values.
+ """
+
+
+class BinaryType(PrimitiveType):
+
+ """Spark SQL BinaryType
+
+ The data type representing bytearray values.
+ """
+
+
+class BooleanType(PrimitiveType):
+
+ """Spark SQL BooleanType
+
+ The data type representing bool values.
+ """
+
+
+class DateType(PrimitiveType):
+
+ """Spark SQL DateType
+
+ The data type representing datetime.date values.
+ """
+
+
+class TimestampType(PrimitiveType):
+
+ """Spark SQL TimestampType
+
+ The data type representing datetime.datetime values.
+ """
+
+
+class DecimalType(DataType):
+
+ """Spark SQL DecimalType
+
+ The data type representing decimal.Decimal values.
+ """
+
+ def __init__(self, precision=None, scale=None):
+ self.precision = precision
+ self.scale = scale
+ self.hasPrecisionInfo = precision is not None
+
+ def jsonValue(self):
+ if self.hasPrecisionInfo:
+ return "decimal(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "decimal"
+
+ def __repr__(self):
+ if self.hasPrecisionInfo:
+ return "DecimalType(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "DecimalType()"
+
+
+class DoubleType(PrimitiveType):
+
+ """Spark SQL DoubleType
+
+ The data type representing float values.
+ """
+
+
+class FloatType(PrimitiveType):
+
+ """Spark SQL FloatType
+
+ The data type representing single precision floating-point values.
+ """
+
+
+class ByteType(PrimitiveType):
+
+ """Spark SQL ByteType
+
+ The data type representing int values stored in a single signed byte.
+ """
+
+
+class IntegerType(PrimitiveType):
+
+ """Spark SQL IntegerType
+
+ The data type representing int values.
+ """
+
+
+class LongType(PrimitiveType):
+
+ """Spark SQL LongType
+
+ The data type representing long values. If any value is
+ beyond the range of [-9223372036854775808, 9223372036854775807],
+ please use DecimalType.
+ """
+
+
+class ShortType(PrimitiveType):
+
+ """Spark SQL ShortType
+
+ The data type representing int values with 2 signed bytes.
+ """
+
+
+class ArrayType(DataType):
+
+ """Spark SQL ArrayType
+
+ The data type representing list values. An ArrayType object
+ comprises two fields, elementType (a DataType) and containsNull (a bool).
+ The field of elementType is used to specify the type of array elements.
+ The field of containsNull is used to specify if the array has None values.
+
+ """
+
+ def __init__(self, elementType, containsNull=True):
+ """Creates an ArrayType
+
+ :param elementType: the data type of elements.
+ :param containsNull: indicates whether the list contains None values.
+
+ >>> ArrayType(StringType) == ArrayType(StringType, True)
+ True
+ >>> ArrayType(StringType, False) == ArrayType(StringType)
+ False
+ """
+ self.elementType = elementType
+ self.containsNull = containsNull
+
+ def __repr__(self):
+ return "ArrayType(%s,%s)" % (self.elementType,
+ str(self.containsNull).lower())
+
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "elementType": self.elementType.jsonValue(),
+ "containsNull": self.containsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return ArrayType(_parse_datatype_json_value(json["elementType"]),
+ json["containsNull"])
+
+
+class MapType(DataType):
+
+ """Spark SQL MapType
+
+ The data type representing dict values. A MapType object comprises
+ three fields, keyType (a DataType), valueType (a DataType) and
+ valueContainsNull (a bool).
+
+ The field of keyType is used to specify the type of keys in the map.
+ The field of valueType is used to specify the type of values in the map.
+ The field of valueContainsNull is used to specify if values of this
+ map can contain None values.
+
+ For values of a MapType column, keys are not allowed to have None values.
+
+ """
+
+ def __init__(self, keyType, valueType, valueContainsNull=True):
+ """Creates a MapType
+ :param keyType: the data type of keys.
+ :param valueType: the data type of values.
+ :param valueContainsNull: indicates whether values contains
+ null values.
+
+ >>> (MapType(StringType, IntegerType)
+ ... == MapType(StringType, IntegerType, True))
+ True
+ >>> (MapType(StringType, IntegerType, False)
+ ... == MapType(StringType, FloatType))
+ False
+ """
+ self.keyType = keyType
+ self.valueType = valueType
+ self.valueContainsNull = valueContainsNull
+
+ def __repr__(self):
+ return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
+ str(self.valueContainsNull).lower())
+
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "keyType": self.keyType.jsonValue(),
+ "valueType": self.valueType.jsonValue(),
+ "valueContainsNull": self.valueContainsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return MapType(_parse_datatype_json_value(json["keyType"]),
+ _parse_datatype_json_value(json["valueType"]),
+ json["valueContainsNull"])
+
+
+class StructField(DataType):
+
+ """Spark SQL StructField
+
+ Represents a field in a StructType.
+ A StructField object comprises three fields, name (a string),
+ dataType (a DataType) and nullable (a bool). The field of name
+ is the name of a StructField. The field of dataType specifies
+ the data type of a StructField.
+
+ The field of nullable specifies if values of a StructField can
+ contain None values.
+
+ """
+
+ def __init__(self, name, dataType, nullable=True, metadata=None):
+ """Creates a StructField
+ :param name: the name of this field.
+ :param dataType: the data type of this field.
+ :param nullable: indicates whether values of this field
+ can be null.
+ :param metadata: metadata of this field, which is a map from string
+ to simple type that can be serialized to JSON
+ automatically
+
+ >>> (StructField("f1", StringType, True)
+ ... == StructField("f1", StringType, True))
+ True
+ >>> (StructField("f1", StringType, True)
+ ... == StructField("f2", StringType, True))
+ False
+ """
+ self.name = name
+ self.dataType = dataType
+ self.nullable = nullable
+ self.metadata = metadata or {}
+
+ def __repr__(self):
+ return "StructField(%s,%s,%s)" % (self.name, self.dataType,
+ str(self.nullable).lower())
+
+ def jsonValue(self):
+ return {"name": self.name,
+ "type": self.dataType.jsonValue(),
+ "nullable": self.nullable,
+ "metadata": self.metadata}
+
+ @classmethod
+ def fromJson(cls, json):
+ return StructField(json["name"],
+ _parse_datatype_json_value(json["type"]),
+ json["nullable"],
+ json["metadata"])
+
+
+class StructType(DataType):
+
+ """Spark SQL StructType
+
+ The data type representing rows.
+ A StructType object comprises a list of L{StructField}.
+
+ """
+
+ def __init__(self, fields):
+ """Creates a StructType
+
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True),
+ ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 == struct2
+ False
+ """
+ self.fields = fields
+
+ def __repr__(self):
+ return ("StructType(List(%s))" %
+ ",".join(str(field) for field in self.fields))
+
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "fields": [f.jsonValue() for f in self.fields]}
+
+ @classmethod
+ def fromJson(cls, json):
+ return StructType([StructField.fromJson(f) for f in json["fields"]])
+
+
+class UserDefinedType(DataType):
+ """
+ .. note:: WARN: Spark Internal Use Only
+ SQL User-Defined Type (UDT).
+ """
+
+ @classmethod
+ def typeName(cls):
+ return cls.__name__.lower()
+
+ @classmethod
+ def sqlType(cls):
+ """
+ Underlying SQL storage type for this UDT.
+ """
+ raise NotImplementedError("UDT must implement sqlType().")
+
+ @classmethod
+ def module(cls):
+ """
+ The Python module of the UDT.
+ """
+ raise NotImplementedError("UDT must implement module().")
+
+ @classmethod
+ def scalaUDT(cls):
+ """
+ The class name of the paired Scala UDT.
+ """
+ raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+ def serialize(self, obj):
+ """
+ Converts a user-type object into a SQL datum.
+ """
+ raise NotImplementedError("UDT must implement serialize().")
+
+ def deserialize(self, datum):
+ """
+ Converts a SQL datum into a user-type object.
+ """
+ raise NotImplementedError("UDT must implement deserialize().")
+
+ def json(self):
+ return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+ def jsonValue(self):
+ schema = {
+ "type": "udt",
+ "class": self.scalaUDT(),
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "sqlType": self.sqlType().jsonValue()
+ }
+ return schema
+
+ @classmethod
+ def fromJson(cls, json):
+ pyUDT = json["pyClass"]
+ split = pyUDT.rfind(".")
+ pyModule = pyUDT[:split]
+ pyClass = pyUDT[split+1:]
+ m = __import__(pyModule, globals(), locals(), [pyClass], -1)
+ UDT = getattr(m, pyClass)
+ return UDT()
+
+ def __eq__(self, other):
+ return type(self) == type(other)
+
+
+_all_primitive_types = dict((v.typeName(), v)
+ for v in globals().itervalues()
+ if type(v) is PrimitiveTypeSingleton and
+ v.__base__ == PrimitiveType)
+
+
+_all_complex_types = dict((v.typeName(), v)
+ for v in [ArrayType, MapType, StructType])
+
+
+def _parse_datatype_json_string(json_string):
+ """Parses the given data type JSON string.
+ >>> def check_datatype(datatype):
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ ... python_datatype = _parse_datatype_json_string(scala_datatype.json())
+ ... return datatype == python_datatype
+ >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
+ True
+ >>> # Simple ArrayType.
+ >>> simple_arraytype = ArrayType(StringType(), True)
+ >>> check_datatype(simple_arraytype)
+ True
+ >>> # Simple MapType.
+ >>> simple_maptype = MapType(StringType(), LongType())
+ >>> check_datatype(simple_maptype)
+ True
+ >>> # Simple StructType.
+ >>> simple_structtype = StructType([
+ ... StructField("a", DecimalType(), False),
+ ... StructField("b", BooleanType(), True),
+ ... StructField("c", LongType(), True),
+ ... StructField("d", BinaryType(), False)])
+ >>> check_datatype(simple_structtype)
+ True
+ >>> # Complex StructType.
+ >>> complex_structtype = StructType([
+ ... StructField("simpleArray", simple_arraytype, True),
+ ... StructField("simpleMap", simple_maptype, True),
+ ... StructField("simpleStruct", simple_structtype, True),
+ ... StructField("boolean", BooleanType(), False),
+ ... StructField("withMeta", DoubleType(), False, {"name": "age"})])
+ >>> check_datatype(complex_structtype)
+ True
+ >>> # Complex ArrayType.
+ >>> complex_arraytype = ArrayType(complex_structtype, True)
+ >>> check_datatype(complex_arraytype)
+ True
+ >>> # Complex MapType.
+ >>> complex_maptype = MapType(complex_structtype,
+ ... complex_arraytype, False)
+ >>> check_datatype(complex_maptype)
+ True
+ >>> check_datatype(ExamplePointUDT())
+ True
+ >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> check_datatype(structtype_with_udt)
+ True
+ """
+ return _parse_datatype_json_value(json.loads(json_string))
+
+
+_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
+
+
+def _parse_datatype_json_value(json_value):
+ if type(json_value) is unicode:
+ if json_value in _all_primitive_types.keys():
+ return _all_primitive_types[json_value]()
+ elif json_value == u'decimal':
+ return DecimalType()
+ elif _FIXED_DECIMAL.match(json_value):
+ m = _FIXED_DECIMAL.match(json_value)
+ return DecimalType(int(m.group(1)), int(m.group(2)))
+ else:
+ raise ValueError("Could not parse datatype: %s" % json_value)
+ else:
+ tpe = json_value["type"]
+ if tpe in _all_complex_types:
+ return _all_complex_types[tpe].fromJson(json_value)
+ elif tpe == 'udt':
+ return UserDefinedType.fromJson(json_value)
+ else:
+ raise ValueError("not supported type: %s" % tpe)
+
+
+# Mapping Python types to Spark SQL DataType
+_type_mappings = {
+ type(None): NullType,
+ bool: BooleanType,
+ int: IntegerType,
+ long: LongType,
+ float: DoubleType,
+ str: StringType,
+ unicode: StringType,
+ bytearray: BinaryType,
+ decimal.Decimal: DecimalType,
+ datetime.date: DateType,
+ datetime.datetime: TimestampType,
+ datetime.time: TimestampType,
+}
+
+
+def _infer_type(obj):
+ """Infer the DataType from obj
+
+ >>> p = ExamplePoint(1.0, 2.0)
+ >>> _infer_type(p)
+ ExamplePointUDT
+ """
+ if obj is None:
+ raise ValueError("Can not infer type for None")
+
+ if hasattr(obj, '__UDT__'):
+ return obj.__UDT__
+
+ dataType = _type_mappings.get(type(obj))
+ if dataType is not None:
+ return dataType()
+
+ if isinstance(obj, dict):
+ for key, value in obj.iteritems():
+ if key is not None and value is not None:
+ return MapType(_infer_type(key), _infer_type(value), True)
+ else:
+ return MapType(NullType(), NullType(), True)
+ elif isinstance(obj, (list, array)):
+ for v in obj:
+ if v is not None:
+ return ArrayType(_infer_type(obj[0]), True)
+ else:
+ return ArrayType(NullType(), True)
+ else:
+ try:
+ return _infer_schema(obj)
+ except ValueError:
+ raise ValueError("not supported type: %s" % type(obj))
+
+
+def _infer_schema(row):
+ """Infer the schema from dict/namedtuple/object"""
+ if isinstance(row, dict):
+ items = sorted(row.items())
+
+ elif isinstance(row, tuple):
+ if hasattr(row, "_fields"): # namedtuple
+ items = zip(row._fields, tuple(row))
+ elif hasattr(row, "__FIELDS__"): # Row
+ items = zip(row.__FIELDS__, tuple(row))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
+ items = row
+ else:
+ raise ValueError("Can't infer schema from tuple")
+
+ elif hasattr(row, "__dict__"): # object
+ items = sorted(row.__dict__.items())
+
+ else:
+ raise ValueError("Can not infer schema for type: %s" % type(row))
+
+ fields = [StructField(k, _infer_type(v), True) for k, v in items]
+ return StructType(fields)
+
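+# A few hedged examples of the inference helpers above (comments only, so the
+# doctest runner does not collect them):
+#
+# >>> _infer_type(1)
+# IntegerType
+# >>> _infer_type([1, 2])
+# ArrayType(IntegerType,true)
+# >>> _infer_schema({"a": 1, "b": u"x"})
+# StructType(List(StructField(a,IntegerType,true),StructField(b,StringType,true)))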
+
+def _need_python_to_sql_conversion(dataType):
+ """
+ Checks whether we need python to sql conversion for the given type.
+ For now, only UDTs need this conversion.
+
+ >>> _need_python_to_sql_conversion(DoubleType())
+ False
+ >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
+ ... StructField("values", ArrayType(DoubleType(), False), False)])
+ >>> _need_python_to_sql_conversion(schema0)
+ False
+ >>> _need_python_to_sql_conversion(ExamplePointUDT())
+ True
+ >>> schema1 = ArrayType(ExamplePointUDT(), False)
+ >>> _need_python_to_sql_conversion(schema1)
+ True
+ >>> schema2 = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> _need_python_to_sql_conversion(schema2)
+ True
+ """
+ if isinstance(dataType, StructType):
+ return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ elif isinstance(dataType, ArrayType):
+ return _need_python_to_sql_conversion(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_python_to_sql_conversion(dataType.keyType) or \
+ _need_python_to_sql_conversion(dataType.valueType)
+ elif isinstance(dataType, UserDefinedType):
+ return True
+ else:
+ return False
+
+
+def _python_to_sql_converter(dataType):
+ """
+ Returns a converter that converts a Python object into a SQL datum for the given type.
+
+ >>> conv = _python_to_sql_converter(DoubleType())
+ >>> conv(1.0)
+ 1.0
+ >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
+ >>> conv([1.0, 2.0])
+ [1.0, 2.0]
+ >>> conv = _python_to_sql_converter(ExamplePointUDT())
+ >>> conv(ExamplePoint(1.0, 2.0))
+ [1.0, 2.0]
+ >>> schema = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> conv = _python_to_sql_converter(schema)
+ >>> conv((1.0, ExamplePoint(1.0, 2.0)))
+ (1.0, [1.0, 2.0])
+ """
+ if not _need_python_to_sql_conversion(dataType):
+ return lambda x: x
+
+ if isinstance(dataType, StructType):
+ names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
+ converters = map(_python_to_sql_converter, types)
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
+ d = dict(obj)
+ return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return converter
+ elif isinstance(dataType, ArrayType):
+ element_converter = _python_to_sql_converter(dataType.elementType)
+ return lambda a: [element_converter(v) for v in a]
+ elif isinstance(dataType, MapType):
+ key_converter = _python_to_sql_converter(dataType.keyType)
+ value_converter = _python_to_sql_converter(dataType.valueType)
+ return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+ elif isinstance(dataType, UserDefinedType):
+ return lambda obj: dataType.serialize(obj)
+ else:
+ raise ValueError("Unexpected type %r" % dataType)
+
+
+def _has_nulltype(dt):
+ """ Return whether there is NullType in `dt` or not """
+ if isinstance(dt, StructType):
+ return any(_has_nulltype(f.dataType) for f in dt.fields)
+ elif isinstance(dt, ArrayType):
+ return _has_nulltype((dt.elementType))
+ elif isinstance(dt, MapType):
+ return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
+ else:
+ return isinstance(dt, NullType)
+
+
+def _merge_type(a, b):
+ if isinstance(a, NullType):
+ return b
+ elif isinstance(b, NullType):
+ return a
+ elif type(a) is not type(b):
+ # TODO: type cast (such as int -> long)
+ raise TypeError("Can not merge type %s and %s" % (a, b))
+
+ # same type
+ if isinstance(a, StructType):
+ nfs = dict((f.name, f.dataType) for f in b.fields)
+ fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+ for f in a.fields]
+ names = set([f.name for f in fields])
+ for n in nfs:
+ if n not in names:
+ fields.append(StructField(n, nfs[n]))
+ return StructType(fields)
+
+ elif isinstance(a, ArrayType):
+ return ArrayType(_merge_type(a.elementType, b.elementType), True)
+
+ elif isinstance(a, MapType):
+ return MapType(_merge_type(a.keyType, b.keyType),
+ _merge_type(a.valueType, b.valueType),
+ True)
+ else:
+ return a
+
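+# Illustrative merge behaviour (comments only): a field present on only one
+# side is kept, because the missing side contributes NullType.
+#
+# >>> a = StructType([StructField("x", IntegerType(), True)])
+# >>> b = StructType([StructField("y", StringType(), True)])
+# >>> _merge_type(a, b)
+# StructType(List(StructField(x,IntegerType,true),StructField(y,StringType,true)))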
+
+def _create_converter(dataType):
+ """Create a converter to drop the names of fields in obj"""
+ if isinstance(dataType, ArrayType):
+ conv = _create_converter(dataType.elementType)
+ return lambda row: map(conv, row)
+
+ elif isinstance(dataType, MapType):
+ kconv = _create_converter(dataType.keyType)
+ vconv = _create_converter(dataType.valueType)
+ return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
+
+ elif isinstance(dataType, NullType):
+ return lambda x: None
+
+ elif not isinstance(dataType, StructType):
+ return lambda x: x
+
+ # dataType must be StructType
+ names = [f.name for f in dataType.fields]
+ converters = [_create_converter(f.dataType) for f in dataType.fields]
+
+ def convert_struct(obj):
+ if obj is None:
+ return
+
+ if isinstance(obj, tuple):
+ if hasattr(obj, "_fields"):
+ d = dict(zip(obj._fields, obj))
+ elif hasattr(obj, "__FIELDS__"):
+ d = dict(zip(obj.__FIELDS__, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
+ d = dict(obj)
+ else:
+ raise ValueError("unexpected tuple: %s" % str(obj))
+
+ elif isinstance(obj, dict):
+ d = obj
+ elif hasattr(obj, "__dict__"): # object
+ d = obj.__dict__
+ else:
+ raise ValueError("Unexpected obj: %s" % obj)
+
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+
+ return convert_struct
+
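+# A hedged sketch of what the struct converter produces: field names are
+# dropped so results line up positionally with the schema fields. For
+# StructType([StructField("a", IntegerType(), True),
+#             StructField("b", IntegerType(), True)]) it maps
+#
+#   {"b": 2, "a": 1}    -> (1, 2)
+#   Row(a=1, b=2)       -> (1, 2)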
+
+_BRACKETS = {'(': ')', '[': ']', '{': '}'}
+
+
+def _split_schema_abstract(s):
+ """
+ split the schema abstract into fields
+
+ >>> _split_schema_abstract("a b c")
+ ['a', 'b', 'c']
+ >>> _split_schema_abstract("a(a b)")
+ ['a(a b)']
+ >>> _split_schema_abstract("a b[] c{a b}")
+ ['a', 'b[]', 'c{a b}']
+ >>> _split_schema_abstract(" ")
+ []
+ """
+
+ r = []
+ w = ''
+ brackets = []
+ for c in s:
+ if c == ' ' and not brackets:
+ if w:
+ r.append(w)
+ w = ''
+ else:
+ w += c
+ if c in _BRACKETS:
+ brackets.append(c)
+ elif c in _BRACKETS.values():
+ if not brackets or c != _BRACKETS[brackets.pop()]:
+ raise ValueError("unexpected " + c)
+
+ if brackets:
+ raise ValueError("brackets not closed: %s" % brackets)
+ if w:
+ r.append(w)
+ return r
+
+
+def _parse_field_abstract(s):
+ """
+ Parse a field in schema abstract
+
+ >>> _parse_field_abstract("a")
+ StructField(a,None,true)
+ >>> _parse_field_abstract("b(c d)")
+ StructField(b,StructType(...c,None,true),StructField(d...
+ >>> _parse_field_abstract("a[]")
+ StructField(a,ArrayType(None,true),true)
+ >>> _parse_field_abstract("a{[]}")
+ StructField(a,MapType(None,ArrayType(None,true),true),true)
+ """
+ if set(_BRACKETS.keys()) & set(s):
+ idx = min((s.index(c) for c in _BRACKETS if c in s))
+ name = s[:idx]
+ return StructField(name, _parse_schema_abstract(s[idx:]), True)
+ else:
+ return StructField(s, None, True)
+
+
+def _parse_schema_abstract(s):
+ """
+ parse abstract into schema
+
+ >>> _parse_schema_abstract("a b c")
+ StructType...a...b...c...
+ >>> _parse_schema_abstract("a[b c] b{}")
+ StructType...a,ArrayType...b...c...b,MapType...
+ >>> _parse_schema_abstract("c{} d{a b}")
+ StructType...c,MapType...d,MapType...a...b...
+ >>> _parse_schema_abstract("a b(t)").fields[1]
+ StructField(b,StructType(List(StructField(t,None,true))),true)
+ """
+ s = s.strip()
+ if not s:
+ return
+
+ elif s.startswith('('):
+ return _parse_schema_abstract(s[1:-1])
+
+ elif s.startswith('['):
+ return ArrayType(_parse_schema_abstract(s[1:-1]), True)
+
+ elif s.startswith('{'):
+ return MapType(None, _parse_schema_abstract(s[1:-1]))
+
+ parts = _split_schema_abstract(s)
+ fields = [_parse_field_abstract(p) for p in parts]
+ return StructType(fields)
+
+
+def _infer_schema_type(obj, dataType):
+ """
+ Fill the dataType with types inferred from obj
+
+ >>> schema = _parse_schema_abstract("a b c d")
+ >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
+ >>> _infer_schema_type(row, schema)
+ StructType...IntegerType...DoubleType...StringType...DateType...
+ >>> row = [[1], {"key": (1, 2.0)}]
+ >>> schema = _parse_schema_abstract("a[] b{c d}")
+ >>> _infer_schema_type(row, schema)
+ StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+ """
+ if dataType is None:
+ return _infer_type(obj)
+
+ if not obj:
+ return NullType()
+
+ if isinstance(dataType, ArrayType):
+ eType = _infer_schema_type(obj[0], dataType.elementType)
+ return ArrayType(eType, True)
+
+ elif isinstance(dataType, MapType):
+ k, v = obj.iteritems().next()
+ return MapType(_infer_schema_type(k, dataType.keyType),
+ _infer_schema_type(v, dataType.valueType))
+
+ elif isinstance(dataType, StructType):
+ fs = dataType.fields
+ assert len(fs) == len(obj), \
+ "Obj(%s) have different length with fields(%s)" % (obj, fs)
+ fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
+ for o, f in zip(obj, fs)]
+ return StructType(fields)
+
+ else:
+ raise ValueError("Unexpected dataType: %s" % dataType)
+
+
+_acceptable_types = {
+ BooleanType: (bool,),
+ ByteType: (int, long),
+ ShortType: (int, long),
+ IntegerType: (int, long),
+ LongType: (int, long),
+ FloatType: (float,),
+ DoubleType: (float,),
+ DecimalType: (decimal.Decimal,),
+ StringType: (str, unicode),
+ BinaryType: (bytearray,),
+ DateType: (datetime.date,),
+ TimestampType: (datetime.datetime,),
+ ArrayType: (list, tuple, array),
+ MapType: (dict,),
+ StructType: (tuple, list),
+}
+
+
+def _verify_type(obj, dataType):
+ """
+ Verify the type of obj against dataType, raise an exception if
+ they do not match.
+
+ >>> _verify_type(None, StructType([]))
+ >>> _verify_type("", StringType())
+ >>> _verify_type(0, IntegerType())
+ >>> _verify_type(range(3), ArrayType(ShortType()))
+ >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ TypeError:...
+ >>> _verify_type({}, MapType(StringType(), IntegerType()))
+ >>> _verify_type((), StructType([]))
+ >>> _verify_type([], StructType([]))
+ >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+ >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ """
+ # all objects are nullable
+ if obj is None:
+ return
+
+ if isinstance(dataType, UserDefinedType):
+ if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+ raise ValueError("%r is not an instance of type %r" % (obj, dataType))
+ _verify_type(dataType.serialize(obj), dataType.sqlType())
+ return
+
+ _type = type(dataType)
+ assert _type in _acceptable_types, "unknown datatype: %s" % dataType
+
+ # subclasses of these types cannot be deserialized in the JVM
+ if type(obj) not in _acceptable_types[_type]:
+ raise TypeError("%s cannot accept an object of type %s"
+ % (dataType, type(obj)))
+
+ if isinstance(dataType, ArrayType):
+ for i in obj:
+ _verify_type(i, dataType.elementType)
+
+ elif isinstance(dataType, MapType):
+ for k, v in obj.iteritems():
+ _verify_type(k, dataType.keyType)
+ _verify_type(v, dataType.valueType)
+
+ elif isinstance(dataType, StructType):
+ if len(obj) != len(dataType.fields):
+ raise ValueError("Length of object (%d) does not match with"
+ "length of fields (%d)" % (len(obj), len(dataType.fields)))
+ for v, f in zip(obj, dataType.fields):
+ _verify_type(v, f.dataType)
+
+
+_cached_cls = {}
+
+
+def _restore_object(dataType, obj):
+ """ Restore object during unpickling. """
+ # use id(dataType) as key to speed up lookup in dict
+ # Because of batched pickling, dataType will be the
+ # same object in most cases.
+ k = id(dataType)
+ cls = _cached_cls.get(k)
+ if cls is None:
+ # use dataType as key to avoid creating multiple classes
+ cls = _cached_cls.get(dataType)
+ if cls is None:
+ cls = _create_cls(dataType)
+ _cached_cls[dataType] = cls
+ _cached_cls[k] = cls
+ return cls(obj)
+
+
+def _create_object(cls, v):
+ """ Create an customized object with class `cls`. """
+ # datetime.date would be deserialized as datetime.datetime
+ # from java type, so we need to set it back.
+ if cls is datetime.date and isinstance(v, datetime.datetime):
+ return v.date()
+ return cls(v) if v is not None else v
+
+
+def _create_getter(dt, i):
+ """ Create a getter for item `i` with schema """
+ cls = _create_cls(dt)
+
+ def getter(self):
+ return _create_object(cls, self[i])
+
+ return getter
+
+
+def _has_struct_or_date(dt):
+ """Return whether `dt` is or has StructType/DateType in it"""
+ if isinstance(dt, StructType):
+ return True
+ elif isinstance(dt, ArrayType):
+ return _has_struct_or_date(dt.elementType)
+ elif isinstance(dt, MapType):
+ return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
+ elif isinstance(dt, DateType):
+ return True
+ elif isinstance(dt, UserDefinedType):
+ return True
+ return False
+
+
+def _create_properties(fields):
+ """Create properties according to fields"""
+ ps = {}
+ for i, f in enumerate(fields):
+ name = f.name
+ if (name.startswith("__") and name.endswith("__")
+ or keyword.iskeyword(name)):
+ warnings.warn("field name %s can not be accessed in Python,"
+ "use position to access it instead" % name)
+ if _has_struct_or_date(f.dataType):
+ # delay creating object until accessing it
+ getter = _create_getter(f.dataType, i)
+ else:
+ getter = itemgetter(i)
+ ps[name] = property(getter)
+ return ps
+
+
+def _create_cls(dataType):
+ """
+ Create a class from the given dataType
+
+ The created class is similar to namedtuple, but can have a nested schema.
+
+ >>> schema = _parse_schema_abstract("a b c")
+ >>> row = (1, 1.0, "str")
+ >>> schema = _infer_schema_type(row, schema)
+ >>> obj = _create_cls(schema)(row)
+ >>> import pickle
+ >>> pickle.loads(pickle.dumps(obj))
+ Row(a=1, b=1.0, c='str')
+
+ >>> row = [[1], {"key": (1, 2.0)}]
+ >>> schema = _parse_schema_abstract("a[] b{c d}")
+ >>> schema = _infer_schema_type(row, schema)
+ >>> obj = _create_cls(schema)(row)
+ >>> pickle.loads(pickle.dumps(obj))
+ Row(a=[1], b={'key': Row(c=1, d=2.0)})
+ >>> pickle.loads(pickle.dumps(obj.a))
+ [1]
+ >>> pickle.loads(pickle.dumps(obj.b))
+ {'key': Row(c=1, d=2.0)}
+ """
+
+ if isinstance(dataType, ArrayType):
+ cls = _create_cls(dataType.elementType)
+
+ def List(l):
+ if l is None:
+ return
+ return [_create_object(cls, v) for v in l]
+
+ return List
+
+ elif isinstance(dataType, MapType):
+ kcls = _create_cls(dataType.keyType)
+ vcls = _create_cls(dataType.valueType)
+
+ def Dict(d):
+ if d is None:
+ return
+ return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
+
+ return Dict
+
+ elif isinstance(dataType, DateType):
+ return datetime.date
+
+ elif isinstance(dataType, UserDefinedType):
+ return lambda datum: dataType.deserialize(datum)
+
+ elif not isinstance(dataType, StructType):
+ # no wrapper for primitive types
+ return lambda x: x
+
+ class Row(tuple):
+
+ """ Row in DataFrame """
+ __DATATYPE__ = dataType
+ __FIELDS__ = tuple(f.name for f in dataType.fields)
+ __slots__ = ()
+
+ # create property for fast access
+ locals().update(_create_properties(dataType.fields))
+
+ def asDict(self):
+ """ Return as a dict """
+ return dict((n, getattr(self, n)) for n in self.__FIELDS__)
+
+ def __repr__(self):
+ # call __repr__ on nested objects too
+ return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
+ for n in self.__FIELDS__))
+
+ def __reduce__(self):
+ return (_restore_object, (self.__DATATYPE__, tuple(self)))
+
+ return Row
+
+
+def _create_row(fields, values):
+ row = Row(*values)
+ row.__FIELDS__ = fields
+ return row
+
+
+class Row(tuple):
+
+ """
+ A row in L{DataFrame}. The fields in it can be accessed like attributes.
+
+ Row can be used to create a row object by using named arguments;
+ the fields will be sorted by name.
+
+ >>> row = Row(name="Alice", age=11)
+ >>> row
+ Row(age=11, name='Alice')
+ >>> row.name, row.age
+ ('Alice', 11)
+
+ Row can also be used to create another Row-like class, which can then
+ be used to create Row objects, for example:
+
+ >>> Person = Row("name", "age")
+ >>> Person
+ <Row(name, age)>
+ >>> Person("Alice", 11)
+ Row(name='Alice', age=11)
+ """
+
+ def __new__(self, *args, **kwargs):
+ if args and kwargs:
+ raise ValueError("Can not use both args "
+ "and kwargs to create Row")
+ if args:
+ # create row class or objects
+ return tuple.__new__(self, args)
+
+ elif kwargs:
+ # create row objects
+ names = sorted(kwargs.keys())
+ values = tuple(kwargs[n] for n in names)
+ row = tuple.__new__(self, values)
+ row.__FIELDS__ = names
+ return row
+
+ else:
+ raise ValueError("No args or kwargs")
+
+ def asDict(self):
+ """
+ Return as a dict
+ """
+ if not hasattr(self, "__FIELDS__"):
+ raise TypeError("Cannot convert a Row class into dict")
+ return dict(zip(self.__FIELDS__, self))
+
+ # let the object act like a class
+ def __call__(self, *args):
+ """create new Row object"""
+ return _create_row(self, args)
+
+ def __getattr__(self, item):
+ if item.startswith("__"):
+ raise AttributeError(item)
+ try:
+ # it will be slow when it has many fields,
+ # but this will not be used in normal cases
+ idx = self.__FIELDS__.index(item)
+ return self[idx]
+ except ValueError:
+ raise AttributeError(item)
+
+ def __reduce__(self):
+ if hasattr(self, "__FIELDS__"):
+ return (_create_row, (self.__FIELDS__, tuple(self)))
+ else:
+ return tuple.__reduce__(self)
+
+ def __repr__(self):
+ if hasattr(self, "__FIELDS__"):
+ return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+ for k, v in zip(self.__FIELDS__, self))
+ else:
+ return "" % ", ".join(self)
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ # let doctest run in pyspark.sql.types, so DataTypes can be picklable
+ import pyspark.sql.types
+ from pyspark.sql import Row, SQLContext
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+ globs = pyspark.sql.types.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+ globs['ExamplePoint'] = ExamplePoint
+ globs['ExamplePointUDT'] = ExamplePointUDT
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/run-tests b/python/run-tests
index 649a2c44d187b..58a26dd8ff088 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -64,8 +64,10 @@ function run_core_tests() {
function run_sql_tests() {
echo "Run sql tests ..."
- run_test "pyspark/sql.py"
- run_test "pyspark/sql_tests.py"
+ run_test "pyspark/sql/types.py"
+ run_test "pyspark/sql/context.py"
+ run_test "pyspark/sql/dataframe.py"
+ run_test "pyspark/sql/tests.py"
}
function run_mllib_tests() {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index e6f622e87f7a4..eb045e37bf5a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
override def sqlType: DataType = ArrayType(DoubleType, false)
- override def pyUDT: String = "pyspark.sql_tests.ExamplePointUDT"
+ override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
override def serialize(obj: Any): Seq[Double] = {
obj match {
From 31d435ecfdc24a788a6e38f4e82767bc275a3283 Mon Sep 17 00:00:00 2001
From: KaiXinXiaoLei
Date: Mon, 9 Feb 2015 20:58:58 -0800
Subject: [PATCH 027/817] Add a config option to print DAG.
Add a config option "spark.logLineage" (named "spark.rddDebug.enable" in earlier revisions of this patch) to control whether to print DAG info. When "spark.logLineage" is true, the RDD's recursive dependencies are printed to the log before each job runs.
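For illustration, a minimal usage sketch (not part of the patch; app name and master are placeholders):

    import org.apache.spark.{SparkConf, SparkContext}

    // Turn on lineage logging before submitting any job.
    val conf = new SparkConf()
      .setAppName("lineage-demo")       // placeholder app name
      .setMaster("local[2]")            // placeholder master
      .set("spark.logLineage", "true")  // option introduced by this patch
    val sc = new SparkContext(conf)
    // Every job run through sc now logs rdd.toDebugString before it starts.
    sc.parallelize(1 to 10).map(_ * 2).count()
    sc.stop()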
Author: KaiXinXiaoLei
Closes #4257 from KaiXinXiaoLei/DAGprint and squashes the following commits:
d9fe42e [KaiXinXiaoLei] change log info
c27ee76 [KaiXinXiaoLei] change log info
83c2b32 [KaiXinXiaoLei] change config option
adcb14f [KaiXinXiaoLei] change the file.
f4e7b9e [KaiXinXiaoLei] add a option to print DAG
---
core/src/main/scala/org/apache/spark/SparkContext.scala | 3 +++
1 file changed, 3 insertions(+)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 71bdbc9b38ddb..8d3c3d000adf3 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1420,6 +1420,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite.shortForm)
+ if (conf.getBoolean("spark.logLineage", false)) {
+ logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
+ }
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
progressBar.foreach(_.finishAll())
From 36c4e1d75933dc843acb747b91dc12e75ad1df42 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Mon, 9 Feb 2015 21:13:58 -0800
Subject: [PATCH 028/817] SPARK-4900 [MLLIB] MLlib SingularValueDecomposition
ARPACK IllegalStateException
Fix the ARPACK error code mapping, at least. It's not yet clear whether the error is what we expect from ARPACK; if it isn't, it's unclear whether that should be treated as an MLlib or a Breeze issue.
Author: Sean Owen
Closes #4485 from srowen/SPARK-4900 and squashes the following commits:
7355aa1 [Sean Owen] Fix ARPACK error code mapping
---
.../org/apache/spark/mllib/linalg/EigenValueDecomposition.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
index 9d6f97528148e..866936aa4f118 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -117,7 +117,7 @@ private[mllib] object EigenValueDecomposition {
info.`val` match {
case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
" Maximum number of iterations taken. (Refer ARPACK user guide for details)")
- case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
" No shifts could be applied. Try to increase NCV. " +
"(Refer ARPACK user guide for details)")
case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
From 20a6013106b56a1a1cc3e8cda092330ffbe77cc3 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Mon, 9 Feb 2015 21:17:06 -0800
Subject: [PATCH 029/817] [SPARK-2996] Implement userClassPathFirst for driver,
yarn.
Yarn's config option `spark.yarn.user.classpath.first` does not work the same way as
`spark.files.userClassPathFirst`; Yarn's version is a lot more dangerous, in that it
modifies the system classpath, instead of restricting the changes to the user's class
loader. So this change implements the behavior of the latter for Yarn, and deprecates
the more dangerous choice.
To be able to achieve feature-parity, I also implemented the option for drivers (the existing
option only applies to executors). So now there are two options, each controlling whether
to apply userClassPathFirst to the driver or executors. The old option was deprecated, and
aliased to the new one (`spark.executor.userClassPathFirst`).
The existing "child-first" class loader also had to be fixed. It didn't handle resources, and it
was also doing some things that ended up causing JVM errors depending on how things
were being called.
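As a rough usage sketch (not part of the patch itself), the two new options are set like any other Spark configuration, and the deprecated executor key is still accepted as an alias:

    import org.apache.spark.SparkConf

    // Prefer user-provided jars over Spark's own, independently for driver and executors.
    val conf = new SparkConf()
      .set("spark.driver.userClassPathFirst", "true")
      .set("spark.executor.userClassPathFirst", "true")

    // The old key still works but is translated to the new executor option and logs
    // a deprecation warning:
    // conf.set("spark.files.userClassPathFirst", "true")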
Author: Marcelo Vanzin
Closes #3233 from vanzin/SPARK-2996 and squashes the following commits:
9cf9cf1 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
a1499e2 [Marcelo Vanzin] Remove SPARK_HOME propagation.
fa7df88 [Marcelo Vanzin] Remove 'test.resource' file, create it dynamically.
a8c69f1 [Marcelo Vanzin] Review feedback.
cabf962 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
a1b8d7e [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
3f768e3 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
2ce3c7a [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
0e6d6be [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
70d4044 [Marcelo Vanzin] Fix pyspark/yarn-cluster test.
0fe7777 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
0e6ef19 [Marcelo Vanzin] Move class loaders around and make names more meaninful.
fe970a7 [Marcelo Vanzin] Review feedback.
25d4fed [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
3cb6498 [Marcelo Vanzin] Call the right loadClass() method on the parent.
fbb8ab5 [Marcelo Vanzin] Add locking in loadClass() to avoid deadlocks.
2e6c4b7 [Marcelo Vanzin] Mention new setting in documentation.
b6497f9 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
a10f379 [Marcelo Vanzin] Some feedback.
3730151 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
f513871 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
44010b6 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
7b57cba [Marcelo Vanzin] Remove now outdated message.
5304d64 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
35949c8 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
54e1a98 [Marcelo Vanzin] Merge branch 'master' into SPARK-2996
d1273b2 [Marcelo Vanzin] Add test file to rat exclude.
fa1aafa [Marcelo Vanzin] Remove write check on user jars.
89d8072 [Marcelo Vanzin] Cleanups.
a963ea3 [Marcelo Vanzin] Implement spark.driver.userClassPathFirst for standalone cluster mode.
50afa5f [Marcelo Vanzin] Fix Yarn executor command line.
7d14397 [Marcelo Vanzin] Register user jars in executor up front.
7f8603c [Marcelo Vanzin] Fix yarn-cluster mode without userClassPathFirst.
20373f5 [Marcelo Vanzin] Fix ClientBaseSuite.
55c88fa [Marcelo Vanzin] Run all Yarn integration tests via spark-submit.
0b64d92 [Marcelo Vanzin] Add deprecation warning to yarn option.
4a84d87 [Marcelo Vanzin] Fix the child-first class loader.
d0394b8 [Marcelo Vanzin] Add "deprecated configs" to SparkConf.
46d8cf2 [Marcelo Vanzin] Update doc with new option, change name to "userClassPathFirst".
a314f2d [Marcelo Vanzin] Enable driver class path isolation in SparkSubmit.
91f7e54 [Marcelo Vanzin] [yarn] Enable executor class path isolation.
a853e74 [Marcelo Vanzin] Re-work CoarseGrainedExecutorBackend command line arguments.
89522ef [Marcelo Vanzin] Add class path isolation support for Yarn cluster mode.
---
.../scala/org/apache/spark/SparkConf.scala | 83 +++++-
.../scala/org/apache/spark/TestUtils.scala | 19 +-
.../org/apache/spark/deploy/Client.scala | 5 +-
.../org/apache/spark/deploy/SparkSubmit.scala | 8 +-
.../spark/deploy/master/ui/MasterPage.scala | 2 +-
.../deploy/rest/StandaloneRestServer.scala | 2 +-
.../spark/deploy/worker/DriverRunner.scala | 15 +-
.../spark/deploy/worker/DriverWrapper.scala | 20 +-
.../CoarseGrainedExecutorBackend.scala | 83 +++++-
.../org/apache/spark/executor/Executor.scala | 52 ++--
.../executor/ExecutorURLClassLoader.scala | 84 ------
.../cluster/SparkDeploySchedulerBackend.scala | 9 +-
.../mesos/CoarseMesosSchedulerBackend.scala | 21 +-
.../spark/util/MutableURLClassLoader.scala | 103 +++++++
.../apache/spark/util/ParentClassLoader.scala | 7 +-
.../org/apache/spark/SparkConfSuite.scala | 12 +
.../spark/deploy/SparkSubmitSuite.scala | 27 ++
.../MutableURLClassLoaderSuite.scala} | 12 +-
docs/configuration.md | 31 +-
pom.xml | 12 +-
project/SparkBuild.scala | 8 +-
.../spark/deploy/yarn/ApplicationMaster.scala | 25 +-
.../org/apache/spark/deploy/yarn/Client.scala | 133 ++++-----
.../spark/deploy/yarn/ExecutorRunnable.scala | 25 +-
yarn/src/test/resources/log4j.properties | 4 +-
.../spark/deploy/yarn/ClientSuite.scala | 6 +-
.../spark/deploy/yarn/YarnClusterSuite.scala | 276 ++++++++++++------
27 files changed, 736 insertions(+), 348 deletions(-)
delete mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
create mode 100644 core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala
rename core/src/test/scala/org/apache/spark/{executor/ExecutorURLClassLoaderSuite.scala => util/MutableURLClassLoaderSuite.scala} (90%)
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 13aa9960ac33a..0dbd26146cb13 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import scala.collection.mutable.LinkedHashSet
@@ -67,7 +68,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
if (value == null) {
throw new NullPointerException("null value for " + key)
}
- settings.put(key, value)
+ settings.put(translateConfKey(key, warn = true), value)
this
}
@@ -139,7 +140,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- settings.putIfAbsent(key, value)
+ settings.putIfAbsent(translateConfKey(key, warn = true), value)
this
}
@@ -175,7 +176,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- Option(settings.get(key))
+ Option(settings.get(translateConfKey(key)))
}
/** Get all parameters as a list of pairs */
@@ -228,7 +229,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getAppId: String = get("spark.app.id")
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.containsKey(key)
+ def contains(key: String): Boolean = settings.containsKey(translateConfKey(key))
/** Copy this object */
override def clone: SparkConf = {
@@ -285,7 +286,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
// Validate memory fractions
val memoryKeys = Seq(
"spark.storage.memoryFraction",
- "spark.shuffle.memoryFraction",
+ "spark.shuffle.memoryFraction",
"spark.shuffle.safetyFraction",
"spark.storage.unrollFraction",
"spark.storage.safetyFraction")
@@ -351,9 +352,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def toDebugString: String = {
getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
+
}
-private[spark] object SparkConf {
+private[spark] object SparkConf extends Logging {
+
+ private val deprecatedConfigs: Map[String, DeprecatedConfig] = {
+ val configs = Seq(
+ DeprecatedConfig("spark.files.userClassPathFirst", "spark.executor.userClassPathFirst",
+ "1.3"),
+ DeprecatedConfig("spark.yarn.user.classpath.first", null, "1.3",
+ "Use spark.{driver,executor}.userClassPathFirst instead."))
+ configs.map { x => (x.oldName, x) }.toMap
+ }
+
/**
* Return whether the given config is an akka config (e.g. akka.actor.provider).
* Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout).
@@ -380,4 +392,63 @@ private[spark] object SparkConf {
def isSparkPortConf(name: String): Boolean = {
(name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")
}
+
+ /**
+ * Translate the configuration key if it is deprecated and has a replacement, otherwise just
+ * returns the provided key.
+ *
+ * @param userKey Configuration key from the user / caller.
+ * @param warn Whether to print a warning if the key is deprecated. Warnings will be printed
+ * only once for each key.
+ */
+ def translateConfKey(userKey: String, warn: Boolean = false): String = {
+ deprecatedConfigs.get(userKey)
+ .map { deprecatedKey =>
+ if (warn) {
+ deprecatedKey.warn()
+ }
+ deprecatedKey.newName.getOrElse(userKey)
+ }.getOrElse(userKey)
+ }
+
+ /**
+ * Holds information about keys that have been deprecated or renamed.
+ *
+ * @param oldName Old configuration key.
+ * @param _newName New configuration key, or `null` if the key has no replacement, in which case the
+ * deprecated key will be used (but the warning message will still be printed).
+ * @param version Version of Spark where key was deprecated.
+ * @param deprecationMessage Message to include in the deprecation warning; mandatory when
+ * `newName` is not provided.
+ */
+ private case class DeprecatedConfig(
+ oldName: String,
+ _newName: String,
+ version: String,
+ deprecationMessage: String = null) {
+
+ private val warned = new AtomicBoolean(false)
+ val newName = Option(_newName)
+
+ if (newName.isEmpty && (deprecationMessage == null || deprecationMessage.isEmpty())) {
+ throw new IllegalArgumentException("Need new config name or deprecation message.")
+ }
+
+ def warn(): Unit = {
+ if (warned.compareAndSet(false, true)) {
+ if (newName.isDefined) {
+ val message = Option(deprecationMessage).getOrElse(
+ s"Please use the alternative '${newName.get}' instead.")
+ logWarning(
+ s"The configuration option '$oldName' has been replaced as of Spark $version and " +
+ s"may be removed in the future. $message")
+ } else {
+ logWarning(
+ s"The configuration option '$oldName' has been deprecated as of Spark $version and " +
+ s"may be removed in the future. $deprecationMessage")
+ }
+ }
+ }
+
+ }
}
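To make the deprecation plumbing above concrete, a small sketch (it mirrors the SparkConfSuite test added later in this patch and is not itself part of the diff):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // set() goes through translateConfKey(key, warn = true), so the value is stored
    // under the new name and a one-time warning is logged for the old one.
    conf.set("spark.files.userClassPathFirst", "true")
    assert(conf.contains("spark.executor.userClassPathFirst"))
    assert(conf.getBoolean("spark.executor.userClassPathFirst", false))
    // spark.yarn.user.classpath.first has no replacement, so it is kept under its own
    // name; the deprecation warning is still printed when it is set.
    conf.set("spark.yarn.user.classpath.first", "true")
    assert(conf.getBoolean("spark.yarn.user.classpath.first", false))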
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index be081c3825566..35b324ba6f573 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -17,12 +17,13 @@
package org.apache.spark
-import java.io.{File, FileInputStream, FileOutputStream}
+import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{URI, URL}
import java.util.jar.{JarEntry, JarOutputStream}
import scala.collection.JavaConversions._
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
@@ -59,6 +60,22 @@ private[spark] object TestUtils {
createJar(files1 ++ files2, jarFile)
}
+ /**
+ * Create a jar file containing multiple files. The `files` map contains a mapping of
+ * file names in the jar file to their contents.
+ */
+ def createJarWithFiles(files: Map[String, String], dir: File = null): URL = {
+ val tempDir = Option(dir).getOrElse(Utils.createTempDir())
+ val jarFile = File.createTempFile("testJar", ".jar", tempDir)
+ val jarStream = new JarOutputStream(new FileOutputStream(jarFile))
+ files.foreach { case (k, v) =>
+ val entry = new JarEntry(k)
+ jarStream.putNextEntry(entry)
+ ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream)
+ }
+ jarStream.close()
+ jarFile.toURI.toURL
+ }
/**
* Create a jar file that contains this set of files. All files will be located at the root
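A quick sketch of how the new helper might be used (the entry name and contents below are simply the ones exercised by the SparkSubmitSuite test added later in this patch):

    import org.apache.spark.TestUtils

    // Build a small jar whose single entry "test.resource" contains the string "USER".
    val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"))
    // The returned URL can then be placed on a class path, for example via
    // --conf spark.driver.extraClassPath=... in spark-submit arguments.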
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 38b3da0b13756..237d26fc6bd0e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -68,8 +68,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
.map(Utils.splitCommandString).getOrElse(Seq.empty)
val sparkJavaOpts = Utils.sparkJavaOpts(conf)
val javaOpts = sparkJavaOpts ++ extraJavaOpts
- val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++
- driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts)
+ val command = new Command(mainClass,
+ Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions,
+ sys.env, classPathEntries, libraryPathEntries, javaOpts)
val driverDescription = new DriverDescription(
driverArgs.jarUrl,
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 6d213926f3d7b..c4bc5054d61a1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -37,7 +37,7 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}
import org.apache.spark.deploy.rest._
import org.apache.spark.executor._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
/**
* Whether to submit, kill, or request the status of an application.
@@ -467,11 +467,11 @@ object SparkSubmit {
}
val loader =
- if (sysProps.getOrElse("spark.files.userClassPathFirst", "false").toBoolean) {
- new ChildExecutorURLClassLoader(new Array[URL](0),
+ if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
+ new ChildFirstURLClassLoader(new Array[URL](0),
Thread.currentThread.getContextClassLoader)
} else {
- new ExecutorURLClassLoader(new Array[URL](0),
+ new MutableURLClassLoader(new Array[URL](0),
Thread.currentThread.getContextClassLoader)
}
Thread.currentThread.setContextClassLoader(loader)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index b47a081053e77..fd514f07664a9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -196,7 +196,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
{Utils.megabytesToString(driver.desc.mem.toLong)}
</td>
- <td>{driver.desc.command.arguments(1)}</td>
+ <td>{driver.desc.command.arguments(2)}</td>
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index 2033d67e1f394..6e4486e20fcba 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -392,7 +392,7 @@ private class SubmitRequestServlet(
val javaOpts = sparkJavaOpts ++ extraJavaOpts
val command = new Command(
"org.apache.spark.deploy.worker.DriverWrapper",
- Seq("{{WORKER_URL}}", mainClass) ++ appArgs, // args to the DriverWrapper
+ Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ appArgs, // args to the DriverWrapper
environmentVariables, extraClassPath, extraLibraryPath, javaOpts)
val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY)
val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 28cab36c7b9e2..b964a09bdb218 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -74,10 +74,15 @@ private[spark] class DriverRunner(
val driverDir = createWorkingDirectory()
val localJarFilename = downloadUserJar(driverDir)
- // Make sure user application jar is on the classpath
+ def substituteVariables(argument: String): String = argument match {
+ case "{{WORKER_URL}}" => workerUrl
+ case "{{USER_JAR}}" => localJarFilename
+ case other => other
+ }
+
// TODO: If we add ability to submit multiple jars they should also be added here
val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem,
- sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename))
+ sparkHome.getAbsolutePath, substituteVariables)
launchDriver(builder, driverDir, driverDesc.supervise)
}
catch {
@@ -111,12 +116,6 @@ private[spark] class DriverRunner(
}
}
- /** Replace variables in a command argument passed to us */
- private def substituteVariables(argument: String): String = argument match {
- case "{{WORKER_URL}}" => workerUrl
- case other => other
- }
-
/**
* Creates the working directory for this driver.
* Will throw an exception if there are errors preparing the directory.
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index 05e242e6df702..ab467a5ee8c6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -17,10 +17,12 @@
package org.apache.spark.deploy.worker
+import java.io.File
+
import akka.actor._
import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
/**
* Utility object for launching driver programs such that they share fate with the Worker process.
@@ -28,21 +30,31 @@ import org.apache.spark.util.{AkkaUtils, Utils}
object DriverWrapper {
def main(args: Array[String]) {
args.toList match {
- case workerUrl :: mainClass :: extraArgs =>
+ case workerUrl :: userJar :: mainClass :: extraArgs =>
val conf = new SparkConf()
val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
Utils.localHostName(), 0, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
+ val currentLoader = Thread.currentThread.getContextClassLoader
+ val userJarUrl = new File(userJar).toURI().toURL()
+ val loader =
+ if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
+ new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader)
+ } else {
+ new MutableURLClassLoader(Array(userJarUrl), currentLoader)
+ }
+ Thread.currentThread.setContextClassLoader(loader)
+
// Delegate to supplied main class
- val clazz = Class.forName(args(1))
+ val clazz = Class.forName(mainClass, true, loader)
val mainMethod = clazz.getMethod("main", classOf[Array[String]])
mainMethod.invoke(null, extraArgs.toArray[String])
actorSystem.shutdown()
case _ =>
- System.err.println("Usage: DriverWrapper [options]")
+ System.err.println("Usage: DriverWrapper [options]")
System.exit(-1)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 3a42f8b157977..dd19e4947db1e 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -17,8 +17,10 @@
package org.apache.spark.executor
+import java.net.URL
import java.nio.ByteBuffer
+import scala.collection.mutable
import scala.concurrent.Await
import akka.actor.{Actor, ActorSelection, Props}
@@ -38,6 +40,7 @@ private[spark] class CoarseGrainedExecutorBackend(
executorId: String,
hostPort: String,
cores: Int,
+ userClassPath: Seq[URL],
env: SparkEnv)
extends Actor with ActorLogReceive with ExecutorBackend with Logging {
@@ -63,7 +66,7 @@ private[spark] class CoarseGrainedExecutorBackend(
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
val (hostname, _) = Utils.parseHostPort(hostPort)
- executor = new Executor(executorId, hostname, env, isLocal = false)
+ executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -117,7 +120,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
hostname: String,
cores: Int,
appId: String,
- workerUrl: Option[String]) {
+ workerUrl: Option[String],
+ userClassPath: Seq[URL]) {
SignalLogger.register(log)
@@ -162,7 +166,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val sparkHostPort = hostname + ":" + boundPort
env.actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, env),
+ driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
name = "Executor")
workerUrl.foreach { url =>
env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
@@ -172,20 +176,69 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
def main(args: Array[String]) {
- args.length match {
- case x if x < 5 =>
- System.err.println(
+ var driverUrl: String = null
+ var executorId: String = null
+ var hostname: String = null
+ var cores: Int = 0
+ var appId: String = null
+ var workerUrl: Option[String] = None
+ val userClassPath = new mutable.ListBuffer[URL]()
+
+ var argv = args.toList
+ while (!argv.isEmpty) {
+ argv match {
+ case ("--driver-url") :: value :: tail =>
+ driverUrl = value
+ argv = tail
+ case ("--executor-id") :: value :: tail =>
+ executorId = value
+ argv = tail
+ case ("--hostname") :: value :: tail =>
+ hostname = value
+ argv = tail
+ case ("--cores") :: value :: tail =>
+ cores = value.toInt
+ argv = tail
+ case ("--app-id") :: value :: tail =>
+ appId = value
+ argv = tail
+ case ("--worker-url") :: value :: tail =>
// Worker url is used in spark standalone mode to enforce fate-sharing with worker
- "Usage: CoarseGrainedExecutorBackend " +
- " [] ")
- System.exit(1)
+ workerUrl = Some(value)
+ argv = tail
+ case ("--user-class-path") :: value :: tail =>
+ userClassPath += new URL(value)
+ argv = tail
+ case Nil =>
+ case tail =>
+ System.err.println(s"Unrecognized options: ${tail.mkString(" ")}")
+ printUsageAndExit()
+ }
+ }
- // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode)
- // and CoarseMesosSchedulerBackend (for mesos mode).
- case 5 =>
- run(args(0), args(1), args(2), args(3).toInt, args(4), None)
- case x if x > 5 =>
- run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5)))
+ if (driverUrl == null || executorId == null || hostname == null || cores <= 0 ||
+ appId == null) {
+ printUsageAndExit()
}
+
+ run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath)
}
+
+ private def printUsageAndExit() = {
+ System.err.println(
+ """
+ |"Usage: CoarseGrainedExecutorBackend [options]
+ |
+ | Options are:
+ | --driver-url <driverUrl>
+ | --executor-id <executorId>
+ | --hostname <hostname>
+ | --cores <cores>
+ | --app-id <appid>
+ | --worker-url <workerUrl>
+ | --user-class-path <url>
+ |""".stripMargin)
+ System.exit(1)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5141483d1e745..6b22dcd6f5cbf 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -19,6 +19,7 @@ package org.apache.spark.executor
import java.io.File
import java.lang.management.ManagementFactory
+import java.net.URL
import java.nio.ByteBuffer
import java.util.concurrent._
@@ -33,7 +34,8 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
+import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader,
+ SparkUncaughtExceptionHandler, AkkaUtils, Utils}
/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -43,6 +45,7 @@ private[spark] class Executor(
executorId: String,
executorHostname: String,
env: SparkEnv,
+ userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false)
extends Logging
{
@@ -288,17 +291,23 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): MutableURLClassLoader = {
+ // Bootstrap the list of jars with the user class path.
+ val now = System.currentTimeMillis()
+ userClassPath.foreach { url =>
+ currentJars(url.getPath().split("/").last) = now
+ }
+
val currentLoader = Utils.getContextOrSparkClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
- val urls = currentJars.keySet.map { uri =>
+ val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
- }.toArray
- val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
- userClassPathFirst match {
- case true => new ChildExecutorURLClassLoader(urls, currentLoader)
- case false => new ExecutorURLClassLoader(urls, currentLoader)
+ }
+ if (conf.getBoolean("spark.executor.userClassPathFirst", false)) {
+ new ChildFirstURLClassLoader(urls, currentLoader)
+ } else {
+ new MutableURLClassLoader(urls, currentLoader)
}
}
@@ -311,7 +320,7 @@ private[spark] class Executor(
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
val userClassPathFirst: java.lang.Boolean =
- conf.getBoolean("spark.files.userClassPathFirst", false)
+ conf.getBoolean("spark.executor.userClassPathFirst", false)
try {
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
@@ -344,18 +353,23 @@ private[spark] class Executor(
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- // Fetch file with useCache mode, close cache for local mode.
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
- env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
- currentJars(name) = timestamp
- // Add it to our class loader
+ for ((name, timestamp) <- newJars) {
val localName = name.split("/").last
- val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
- if (!urlClassLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- urlClassLoader.addURL(url)
+ val currentTimeStamp = currentJars.get(name)
+ .orElse(currentJars.get(localName))
+ .getOrElse(-1L)
+ if (currentTimeStamp < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ // Fetch file with useCache mode, close cache for local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
deleted file mode 100644
index 8011e75944aac..0000000000000
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import java.net.{URLClassLoader, URL}
-
-import org.apache.spark.util.ParentClassLoader
-
-/**
- * The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
- * We also make changes so user classes can come before the default classes.
- */
-
-private[spark] trait MutableURLClassLoader extends ClassLoader {
- def addURL(url: URL)
- def getURLs: Array[URL]
-}
-
-private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
- extends MutableURLClassLoader {
-
- private object userClassLoader extends URLClassLoader(urls, null){
- override def addURL(url: URL) {
- super.addURL(url)
- }
- override def findClass(name: String): Class[_] = {
- val loaded = super.findLoadedClass(name)
- if (loaded != null) {
- return loaded
- }
- try {
- super.findClass(name)
- } catch {
- case e: ClassNotFoundException => {
- parentClassLoader.loadClass(name)
- }
- }
- }
- }
-
- private val parentClassLoader = new ParentClassLoader(parent)
-
- override def findClass(name: String): Class[_] = {
- try {
- userClassLoader.findClass(name)
- } catch {
- case e: ClassNotFoundException => {
- parentClassLoader.loadClass(name)
- }
- }
- }
-
- def addURL(url: URL) {
- userClassLoader.addURL(url)
- }
-
- def getURLs() = {
- userClassLoader.getURLs()
- }
-}
-
-private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
- extends URLClassLoader(urls, parent) with MutableURLClassLoader {
-
- override def addURL(url: URL) {
- super.addURL(url)
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index d2e1680a5fd1b..40fc6b59cdf7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -52,8 +52,13 @@ private[spark] class SparkDeploySchedulerBackend(
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
- val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}",
- "{{WORKER_URL}}")
+ val args = Seq(
+ "--driver-url", driverUrl,
+ "--executor-id", "{{EXECUTOR_ID}}",
+ "--hostname", "{{HOSTNAME}}",
+ "--cores", "{{CORES}}",
+ "--app-id", "{{APP_ID}}",
+ "--worker-url", "{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
.map(Utils.splitCommandString).getOrElse(Seq.empty)
val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 0d1c2a916ca7f..90dfe14352a8e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -154,18 +154,25 @@ private[spark] class CoarseMesosSchedulerBackend(
if (uri == null) {
val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
- prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue,
- offer.getHostname, numCores, appId))
+ "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend"
+ .format(prefixEnv, runScript) +
+ s" --driver-url $driverUrl" +
+ s" --executor-id ${offer.getSlaveId.getValue}" +
+ s" --hostname ${offer.getHostname}" +
+ s" --cores $numCores" +
+ s" --app-id $appId")
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- ("cd %s*; %s " +
- "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s")
- .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue,
- offer.getHostname, numCores, appId))
+ s"cd $basename*; $prefixEnv " +
+ "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" +
+ s" --driver-url $driverUrl" +
+ s" --executor-id ${offer.getSlaveId.getValue}" +
+ s" --hostname ${offer.getHostname}" +
+ s" --cores $numCores" +
+ s" --app-id $appId")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
command.build()
diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala
new file mode 100644
index 0000000000000..d9c7103b2f3bf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.net.{URLClassLoader, URL}
+import java.util.Enumeration
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.util.ParentClassLoader
+
+/**
+ * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader.
+ */
+private[spark] class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader)
+ extends URLClassLoader(urls, parent) {
+
+ override def addURL(url: URL): Unit = {
+ super.addURL(url)
+ }
+
+ override def getURLs(): Array[URL] = {
+ super.getURLs()
+ }
+
+}
+
+/**
+ * A mutable class loader that gives preference to its own URLs over the parent class loader
+ * when loading classes and resources.
+ */
+private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoader)
+ extends MutableURLClassLoader(urls, null) {
+
+ private val parentClassLoader = new ParentClassLoader(parent)
+
+ /**
+ * Used to implement fine-grained class loading locks similar to what is done by Java 7. This
+ * prevents deadlock issues when using non-hierarchical class loaders.
+ *
+ * Note that due to Java 6 compatibility (and some issues with implementing class loaders in
+ * Scala), Java 7's `ClassLoader.registerAsParallelCapable` method is not called.
+ */
+ private val locks = new ConcurrentHashMap[String, Object]()
+
+ override def loadClass(name: String, resolve: Boolean): Class[_] = {
+ var lock = locks.get(name)
+ if (lock == null) {
+ val newLock = new Object()
+ lock = locks.putIfAbsent(name, newLock)
+ if (lock == null) {
+ lock = newLock
+ }
+ }
+
+ lock.synchronized {
+ try {
+ super.loadClass(name, resolve)
+ } catch {
+ case e: ClassNotFoundException =>
+ parentClassLoader.loadClass(name, resolve)
+ }
+ }
+ }
+
+ override def getResource(name: String): URL = {
+ val url = super.findResource(name)
+ val res = if (url != null) url else parentClassLoader.getResource(name)
+ res
+ }
+
+ override def getResources(name: String): Enumeration[URL] = {
+ val urls = super.findResources(name)
+ val res =
+ if (urls != null && urls.hasMoreElements()) {
+ urls
+ } else {
+ parentClassLoader.getResources(name)
+ }
+ res
+ }
+
+ override def addURL(url: URL) {
+ super.addURL(url)
+ }
+
+}
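As a usage sketch (jar paths below are hypothetical; the renamed test suite later in this patch exercises the same behavior), the child-first loader prefers classes and resources from the given URLs and only falls back to the parent:

    import java.net.URL
    import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader}

    val userJars: Array[URL] = Array(new URL("file:/tmp/user-app.jar"))  // hypothetical jar
    val parent = Thread.currentThread().getContextClassLoader

    // Child-first: classes found in user-app.jar shadow the parent's copies.
    val childFirst = new ChildFirstURLClassLoader(userJars, parent)
    // Parent-first (default behavior): a plain URLClassLoader with a public addURL.
    val parentFirst = new MutableURLClassLoader(userJars, parent)
    parentFirst.addURL(new URL("file:/tmp/extra-lib.jar"))  // hypothetical jar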
diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
index 3abc12681fe9a..6d8d9e8da3678 100644
--- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
/**
- * A class loader which makes findClass accesible to the child
+ * A class loader which makes some protected methods in ClassLoader accesible.
*/
private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) {
@@ -29,4 +29,9 @@ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(
override def loadClass(name: String): Class[_] = {
super.loadClass(name)
}
+
+ override def loadClass(name: String, resolve: Boolean): Class[_] = {
+ super.loadClass(name, resolve)
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index e08210ae60d17..ea6b73bc68b34 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -197,6 +197,18 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro
serializer.newInstance().serialize(new StringBuffer())
}
+ test("deprecated config keys") {
+ val conf = new SparkConf()
+ .set("spark.files.userClassPathFirst", "true")
+ .set("spark.yarn.user.classpath.first", "true")
+ assert(conf.contains("spark.files.userClassPathFirst"))
+ assert(conf.contains("spark.executor.userClassPathFirst"))
+ assert(conf.contains("spark.yarn.user.classpath.first"))
+ assert(conf.getBoolean("spark.files.userClassPathFirst", false))
+ assert(conf.getBoolean("spark.executor.userClassPathFirst", false))
+ assert(conf.getBoolean("spark.yarn.user.classpath.first", false))
+ }
+
}
class Class1 {}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 1ddccae1262bc..46d745c4ecbfa 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -21,6 +21,8 @@ import java.io._
import scala.collection.mutable.ArrayBuffer
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.ByteStreams
import org.scalatest.FunSuite
import org.scalatest.Matchers
import org.scalatest.concurrent.Timeouts
@@ -450,6 +452,19 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(","))
}
+ test("user classpath first in driver") {
+ val systemJar = TestUtils.createJarWithFiles(Map("test.resource" -> "SYSTEM"))
+ val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"))
+ val args = Seq(
+ "--class", UserClasspathFirstTest.getClass.getName.stripSuffix("$"),
+ "--name", "testApp",
+ "--master", "local",
+ "--conf", "spark.driver.extraClassPath=" + systemJar,
+ "--conf", "spark.driver.userClassPathFirst=true",
+ userJar.toString)
+ runSparkSubmit(args)
+ }
+
test("SPARK_CONF_DIR overrides spark-defaults.conf") {
forConfDir(Map("spark.executor.memory" -> "2.3g")) { path =>
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
@@ -541,3 +556,15 @@ object SimpleApplicationTest {
}
}
}
+
+object UserClasspathFirstTest {
+ def main(args: Array[String]) {
+ val ccl = Thread.currentThread().getContextClassLoader()
+ val resource = ccl.getResourceAsStream("test.resource")
+ val bytes = ByteStreams.toByteArray(resource)
+ val contents = new String(bytes, 0, bytes.length, UTF_8)
+ if (contents != "USER") {
+ throw new SparkException("Should have read user resource, but instead read: " + contents)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala
similarity index 90%
rename from core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala
rename to core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala
index b7912c09d1410..31e3b7e7bb71b 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.executor
+package org.apache.spark.util
import java.net.URLClassLoader
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils}
import org.apache.spark.util.Utils
-class ExecutorURLClassLoaderSuite extends FunSuite {
+class MutableURLClassLoaderSuite extends FunSuite {
val urls2 = List(TestUtils.createJarWithClasses(
classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"),
@@ -37,7 +37,7 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
test("child first") {
val parentLoader = new URLClassLoader(urls2, null)
- val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader)
+ val classLoader = new ChildFirstURLClassLoader(urls, parentLoader)
val fakeClass = classLoader.loadClass("FakeClass2").newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "1")
@@ -47,7 +47,7 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
test("parent first") {
val parentLoader = new URLClassLoader(urls2, null)
- val classLoader = new ExecutorURLClassLoader(urls, parentLoader)
+ val classLoader = new MutableURLClassLoader(urls, parentLoader)
val fakeClass = classLoader.loadClass("FakeClass1").newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "2")
@@ -57,7 +57,7 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
test("child first can fall back") {
val parentLoader = new URLClassLoader(urls2, null)
- val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader)
+ val classLoader = new ChildFirstURLClassLoader(urls, parentLoader)
val fakeClass = classLoader.loadClass("FakeClass3").newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "2")
@@ -65,7 +65,7 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
test("child first can fail") {
val parentLoader = new URLClassLoader(urls2, null)
- val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader)
+ val classLoader = new ChildFirstURLClassLoader(urls, parentLoader)
intercept[java.lang.ClassNotFoundException] {
classLoader.loadClass("FakeClassDoesNotExist").newInstance()
}
diff --git a/docs/configuration.md b/docs/configuration.md
index 00e973c245005..eb0d6d33c97d9 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -230,6 +230,15 @@ Apart from these, the following properties are also available, and may be useful
Set a special library path to use when launching the driver JVM.
+
+ spark.driver.userClassPathFirst |
+ false |
+
+ (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading
+ classes in the driver. This feature can be used to mitigate conflicts between Spark's
+ dependencies and user dependencies. It is currently an experimental feature.
+ |
+
spark.executor.extraJavaOptions |
(none) |
@@ -297,13 +306,11 @@ Apart from these, the following properties are also available, and may be useful
- spark.files.userClassPathFirst |
+ spark.executor.userClassPathFirst |
false |
- (Experimental) Whether to give user-added jars precedence over Spark's own jars when
- loading classes in Executors. This feature can be used to mitigate conflicts between
- Spark's dependencies and user dependencies. It is currently an experimental feature.
- (Currently, this setting does not work for YARN, see SPARK-2996 for more details).
+ (Experimental) Same functionality as spark.driver.userClassPathFirst , but
+ applied to executor instances.
|
@@ -865,8 +872,8 @@ Apart from these, the following properties are also available, and may be useful
spark.network.timeout |
120 |
- Default timeout for all network interactions, in seconds. This config will be used in
- place of spark.core.connection.ack.wait.timeout , spark.akka.timeout ,
+ Default timeout for all network interactions, in seconds. This config will be used in
+ place of spark.core.connection.ack.wait.timeout , spark.akka.timeout ,
spark.storage.blockManagerSlaveTimeoutMs or
spark.shuffle.io.connectionTimeout , if they are not configured.
|
@@ -911,8 +918,8 @@ Apart from these, the following properties are also available, and may be useful
spark.shuffle.io.preferDirectBufs |
true |
- (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache
- block transfer. For environments where off-heap memory is tightly limited, users may wish to
+ (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache
+ block transfer. For environments where off-heap memory is tightly limited, users may wish to
turn this off to force all allocations from Netty to be on-heap.
|
@@ -920,7 +927,7 @@ Apart from these, the following properties are also available, and may be useful
spark.shuffle.io.numConnectionsPerPeer |
1 |
- (Netty only) Connections between hosts are reused in order to reduce connection buildup for
+ (Netty only) Connections between hosts are reused in order to reduce connection buildup for
large clusters. For clusters with many hard disks and few hosts, this may result in insufficient
concurrency to saturate all disks, and so users may consider increasing this value.
|
@@ -930,7 +937,7 @@ Apart from these, the following properties are also available, and may be useful
3 |
(Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is
- set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC
+ set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC
pauses or transient network connectivity issues.
|
@@ -939,7 +946,7 @@ Apart from these, the following properties are also available, and may be useful
5 |
(Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying
- is simply maxRetries * retryWait , by default 15 seconds.
+ is simply maxRetries * retryWait , by default 15 seconds.
|
diff --git a/pom.xml b/pom.xml
index f6f176d2004b7..a9e968af25453 100644
--- a/pom.xml
+++ b/pom.xml
@@ -342,7 +342,7 @@
-
+
@@ -395,7 +395,7 @@
provided
-
+
org.apache.commons
commons-lang3
@@ -1178,13 +1178,19 @@
${project.build.directory}/surefire-reports
-Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
+
+
+ ${test_classpath}
+
true
${session.executionRootDirectory}
1
false
false
- ${test_classpath}
true
false
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 95f8dfa3d270f..8fb1239b4a96b 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -411,6 +411,10 @@ object TestSettings {
lazy val settings = Seq (
// Fork new JVMs for tests and set Java options for those
fork := true,
+ // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes
+ // launched by the tests have access to the correct test-time classpath.
+ envVars in Test += ("SPARK_DIST_CLASSPATH" ->
+ (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":")),
javaOptions in Test += "-Dspark.test.home=" + sparkHome,
javaOptions in Test += "-Dspark.testing=1",
javaOptions in Test += "-Dspark.port.maxRetries=100",
@@ -423,10 +427,6 @@ object TestSettings {
javaOptions in Test += "-ea",
javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
.split(" ").toSeq,
- // This places test scope jars on the classpath of executors during tests.
- javaOptions in Test +=
- "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files.
- map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
javaOptions += "-Xmx3g",
// Show full stack trace and duration in test cases.
testOptions in Test += Tests.Argument("-oDF"),
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 4cc320c5d59b5..a9bf861d160c1 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -19,9 +19,9 @@ package org.apache.spark.deploy.yarn
import scala.util.control.NonFatal
-import java.io.IOException
+import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
-import java.net.Socket
+import java.net.{Socket, URL}
import java.util.concurrent.atomic.AtomicReference
import akka.actor._
@@ -38,7 +38,8 @@ import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.scheduler.cluster.YarnSchedulerBackend
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader,
+ SignalLogger, Utils}
/**
* Common application master functionality for Spark on Yarn.
@@ -244,7 +245,6 @@ private[spark] class ApplicationMaster(
host: String,
port: String,
isClusterMode: Boolean): Unit = {
-
val driverUrl = AkkaUtils.address(
AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
@@ -453,12 +453,24 @@ private[spark] class ApplicationMaster(
private def startUserApplication(): Thread = {
logInfo("Starting the user application in a separate Thread")
System.setProperty("spark.executor.instances", args.numExecutors.toString)
+
+ val classpath = Client.getUserClasspath(sparkConf)
+ val urls = classpath.map { entry =>
+ new URL("file:" + new File(entry.getPath()).getAbsolutePath())
+ }
+ val userClassLoader =
+ if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) {
+ new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
+ } else {
+ new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
+ }
+
if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
System.setProperty("spark.submit.pyFiles",
PythonRunner.formatPaths(args.pyFiles).mkString(","))
}
- val mainMethod = Class.forName(args.userClass, false,
- Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
+ val mainMethod = userClassLoader.loadClass(args.userClass)
+ .getMethod("main", classOf[Array[String]])
val userThread = new Thread {
override def run() {
@@ -483,6 +495,7 @@ private[spark] class ApplicationMaster(
}
}
}
+ userThread.setContextClassLoader(userClassLoader)
userThread.setName("Driver")
userThread.start()
userThread
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 8afc1ccdad732..46d9df93488cb 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -183,8 +183,7 @@ private[spark] class Client(
private[yarn] def copyFileToRemote(
destDir: Path,
srcPath: Path,
- replication: Short,
- setPerms: Boolean = false): Path = {
+ replication: Short): Path = {
val destFs = destDir.getFileSystem(hadoopConf)
val srcFs = srcPath.getFileSystem(hadoopConf)
var destPath = srcPath
@@ -193,9 +192,7 @@ private[spark] class Client(
logInfo(s"Uploading resource $srcPath -> $destPath")
FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf)
destFs.setReplication(destPath, replication)
- if (setPerms) {
- destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION))
- }
+ destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION))
} else {
logInfo(s"Source and destination file systems are the same. Not copying $srcPath")
}
@@ -239,23 +236,22 @@ private[spark] class Client(
/**
* Copy the given main resource to the distributed cache if the scheme is not "local".
* Otherwise, set the corresponding key in our SparkConf to handle it downstream.
- * Each resource is represented by a 4-tuple of:
+ * Each resource is represented by a 3-tuple of:
* (1) destination resource name,
* (2) local path to the resource,
- * (3) Spark property key to set if the scheme is not local, and
- * (4) whether to set permissions for this resource
+ * (3) Spark property key to set if the scheme is not local
*/
List(
- (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false),
- (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true),
- ("log4j.properties", oldLog4jConf.orNull, null, false)
- ).foreach { case (destName, _localPath, confKey, setPermissions) =>
+ (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR),
+ (APP_JAR, args.userJar, CONF_SPARK_USER_JAR),
+ ("log4j.properties", oldLog4jConf.orNull, null)
+ ).foreach { case (destName, _localPath, confKey) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (!localPath.isEmpty()) {
val localURI = new URI(localPath)
if (localURI.getScheme != LOCAL_SCHEME) {
val src = getQualifiedLocalPath(localURI, hadoopConf)
- val destPath = copyFileToRemote(dst, src, replication, setPermissions)
+ val destPath = copyFileToRemote(dst, src, replication)
val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
distCacheMgr.addResource(destFs, hadoopConf, destPath,
localResources, LocalResourceType.FILE, destName, statCache)
@@ -707,7 +703,7 @@ object Client extends Logging {
* Return the path to the given application's staging directory.
*/
private def getAppStagingDir(appId: ApplicationId): String = {
- SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR
+ buildPath(SPARK_STAGING, appId.toString())
}
/**
@@ -783,7 +779,13 @@ object Client extends Logging {
/**
* Populate the classpath entry in the given environment map.
- * This includes the user jar, Spark jar, and any extra application jars.
+ *
+ * User jars are generally not added to the JVM's system classpath; those are handled by the AM
+ * and executor backend. When the deprecated `spark.yarn.user.classpath.first` is used, user jars
+ * are included in the system classpath, though. The extra class path and other uploaded files are
+ * always made available through the system class path.
+ *
+ * @param args Client arguments (when starting the AM) or null (when starting executors).
*/
private[yarn] def populateClasspath(
args: ClientArguments,
@@ -795,48 +797,38 @@ object Client extends Logging {
addClasspathEntry(
YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env
)
-
- // Normally the users app.jar is last in case conflicts with spark jars
if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) {
- addUserClasspath(args, sparkConf, env)
- addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env)
- populateHadoopClasspath(conf, env)
- } else {
- addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env)
- populateHadoopClasspath(conf, env)
- addUserClasspath(args, sparkConf, env)
+ val userClassPath =
+ if (args != null) {
+ getUserClasspath(Option(args.userJar), Option(args.addJars))
+ } else {
+ getUserClasspath(sparkConf)
+ }
+ userClassPath.foreach { x =>
+ addFileToClasspath(x, null, env)
+ }
}
-
- // Append all jar files under the working directory to the classpath.
- addClasspathEntry(
- YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + "*", env
- )
+ addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env)
+ populateHadoopClasspath(conf, env)
+ sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env))
}
/**
- * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly
- * to the classpath.
+ * Returns a list of URIs representing the user classpath.
+ *
+ * @param conf Spark configuration.
*/
- private def addUserClasspath(
- args: ClientArguments,
- conf: SparkConf,
- env: HashMap[String, String]): Unit = {
-
- // If `args` is not null, we are launching an AM container.
- // Otherwise, we are launching executor containers.
- val (mainJar, secondaryJars) =
- if (args != null) {
- (args.userJar, args.addJars)
- } else {
- (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null))
- }
+ def getUserClasspath(conf: SparkConf): Array[URI] = {
+ getUserClasspath(conf.getOption(CONF_SPARK_USER_JAR),
+ conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS))
+ }
- addFileToClasspath(mainJar, APP_JAR, env)
- if (secondaryJars != null) {
- secondaryJars.split(",").filter(_.nonEmpty).foreach { jar =>
- addFileToClasspath(jar, null, env)
- }
- }
+ private def getUserClasspath(
+ mainJar: Option[String],
+ secondaryJars: Option[String]): Array[URI] = {
+ val mainUri = mainJar.orElse(Some(APP_JAR)).map(new URI(_))
+ val secondaryUris = secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_))
+ (mainUri ++ secondaryUris).toArray
}
/**
@@ -847,27 +839,19 @@ object Client extends Logging {
*
* If not a "local:" file and no alternate name, the environment is not modified.
*
- * @param path Path to add to classpath (optional).
+ * @param uri URI to add to classpath (optional).
* @param fileName Alternate name for the file (optional).
* @param env Map holding the environment variables.
*/
private def addFileToClasspath(
- path: String,
+ uri: URI,
fileName: String,
env: HashMap[String, String]): Unit = {
- if (path != null) {
- scala.util.control.Exception.ignoring(classOf[URISyntaxException]) {
- val uri = new URI(path)
- if (uri.getScheme == LOCAL_SCHEME) {
- addClasspathEntry(uri.getPath, env)
- return
- }
- }
- }
- if (fileName != null) {
- addClasspathEntry(
- YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + fileName, env
- )
+ if (uri != null && uri.getScheme == LOCAL_SCHEME) {
+ addClasspathEntry(uri.getPath, env)
+ } else if (fileName != null) {
+ addClasspathEntry(buildPath(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env)
}
}
@@ -963,4 +947,23 @@ object Client extends Logging {
new Path(qualifiedURI)
}
+ /**
+ * Whether to consider jars provided by the user to have precedence over the Spark jars when
+ * loading user classes.
+ */
+ def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = {
+ if (isDriver) {
+ conf.getBoolean("spark.driver.userClassPathFirst", false)
+ } else {
+ conf.getBoolean("spark.executor.userClassPathFirst", false)
+ }
+ }
+
+ /**
+ * Joins all the path components using Path.SEPARATOR.
+ */
+ def buildPath(components: String*): String = {
+ components.mkString(Path.SEPARATOR)
+ }
+
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 7cd8c5f0f9204..6d5b8fda76ab8 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy.yarn
+import java.io.File
import java.net.URI
import java.nio.ByteBuffer
@@ -57,7 +58,7 @@ class ExecutorRunnable(
var nmClient: NMClient = _
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
lazy val env = prepareEnvironment(container)
-
+
def run = {
logInfo("Starting Executor Container")
nmClient = NMClient.createNMClient()
@@ -185,6 +186,16 @@ class ExecutorRunnable(
// For log4j configuration to reference
javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR)
+ val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri =>
+ val absPath =
+ if (new File(uri.getPath()).isAbsolute()) {
+ uri.getPath()
+ } else {
+ Client.buildPath(Environment.PWD.$(), uri.getPath())
+ }
+ Seq("--user-class-path", "file:" + absPath)
+ }.toSeq
+
val commands = prefixEnv ++ Seq(
YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java",
"-server",
@@ -196,11 +207,13 @@ class ExecutorRunnable(
"-XX:OnOutOfMemoryError='kill %p'") ++
javaOpts ++
Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend",
- masterAddress.toString,
- slaveId.toString,
- hostname.toString,
- executorCores.toString,
- appId,
+ "--driver-url", masterAddress.toString,
+ "--executor-id", slaveId.toString,
+ "--hostname", hostname.toString,
+ "--cores", executorCores.toString,
+ "--app-id", appId) ++
+ userClassPath ++
+ Seq(
"1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",
"2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties
index 287c8e3563503..aab41fa49430f 100644
--- a/yarn/src/test/resources/log4j.properties
+++ b/yarn/src/test/resources/log4j.properties
@@ -16,7 +16,7 @@
#
# Set everything to be logged to the file target/unit-tests.log
-log4j.rootCategory=INFO, file
+log4j.rootCategory=DEBUG, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=true
log4j.appender.file.file=target/unit-tests.log
@@ -25,4 +25,4 @@ log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
-org.eclipse.jetty.LEVEL=WARN
+log4j.logger.org.apache.hadoop=WARN
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 2bb3dcffd61d9..f8f8129d220e4 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -82,6 +82,7 @@ class ClientSuite extends FunSuite with Matchers {
test("Local jar URIs") {
val conf = new Configuration()
val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK)
+ .set("spark.yarn.user.classpath.first", "true")
val env = new MutableHashMap[String, String]()
val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
@@ -98,13 +99,10 @@ class ClientSuite extends FunSuite with Matchers {
})
if (classOf[Environment].getMethods().exists(_.getName == "$$")) {
cp should contain("{{PWD}}")
- cp should contain(s"{{PWD}}${Path.SEPARATOR}*")
} else if (Utils.isWindows) {
cp should contain("%PWD%")
- cp should contain(s"%PWD%${Path.SEPARATOR}*")
} else {
cp should contain(Environment.PWD.$())
- cp should contain(s"${Environment.PWD.$()}${File.separator}*")
}
cp should not contain (Client.SPARK_JAR)
cp should not contain (Client.APP_JAR)
@@ -117,7 +115,7 @@ class ClientSuite extends FunSuite with Matchers {
val client = spy(new Client(args, conf, sparkConf))
doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]),
- any(classOf[Path]), anyShort(), anyBoolean())
+ any(classOf[Path]), anyShort())
val tempDir = Utils.createTempDir()
try {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index e39de82740b1d..0e37276ba724b 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -17,27 +17,34 @@
package org.apache.spark.deploy.yarn
-import java.io.File
+import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.util.Properties
import java.util.concurrent.TimeUnit
import scala.collection.JavaConversions._
import scala.collection.mutable
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.ByteStreams
import com.google.common.io.Files
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.MiniYARNCluster
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils}
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}
import org.apache.spark.util.Utils
+/**
+ * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN
+ * applications, and require the Spark assembly to be built before they can be successfully
+ * run.
+ */
class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging {
- // log4j configuration for the Yarn containers, so that their output is collected
- // by Yarn instead of trying to overwrite unit-tests.log.
+ // log4j configuration for the YARN containers, so that their output is collected
+ // by YARN instead of trying to overwrite unit-tests.log.
private val LOG4J_CONF = """
|log4j.rootCategory=DEBUG, console
|log4j.appender.console=org.apache.log4j.ConsoleAppender
@@ -52,13 +59,11 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
|
|from pyspark import SparkConf , SparkContext
|if __name__ == "__main__":
- | if len(sys.argv) != 3:
- | print >> sys.stderr, "Usage: test.py [master] [result file]"
+ | if len(sys.argv) != 2:
+ | print >> sys.stderr, "Usage: test.py [result file]"
| exit(-1)
- | conf = SparkConf()
- | conf.setMaster(sys.argv[1]).setAppName("python test in yarn cluster mode")
- | sc = SparkContext(conf=conf)
- | status = open(sys.argv[2],'w')
+ | sc = SparkContext(conf=SparkConf())
+ | status = open(sys.argv[1],'w')
| result = "failure"
| rdd = sc.parallelize(range(10))
| cnt = rdd.count()
@@ -72,23 +77,17 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
private var yarnCluster: MiniYARNCluster = _
private var tempDir: File = _
private var fakeSparkJar: File = _
- private var oldConf: Map[String, String] = _
+ private var logConfDir: File = _
override def beforeAll() {
super.beforeAll()
tempDir = Utils.createTempDir()
-
- val logConfDir = new File(tempDir, "log4j")
+ logConfDir = new File(tempDir, "log4j")
logConfDir.mkdir()
val logConfFile = new File(logConfDir, "log4j.properties")
- Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8)
-
- val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator +
- sys.props("java.class.path")
-
- oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap
+ Files.write(LOG4J_CONF, logConfFile, UTF_8)
yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
yarnCluster.init(new YarnConfiguration())
@@ -119,99 +118,165 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
}
logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
- config.foreach { e =>
- sys.props += ("spark.hadoop." + e.getKey() -> e.getValue())
- }
fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
- val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- sys.props += ("spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome)
- sys.props += ("spark.executorEnv.SPARK_HOME" -> sparkHome)
- sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath()))
- sys.props += ("spark.executor.instances" -> "1")
- sys.props += ("spark.driver.extraClassPath" -> childClasspath)
- sys.props += ("spark.executor.extraClassPath" -> childClasspath)
- sys.props += ("spark.executor.extraJavaOptions" -> "-Dfoo=\"one two three\"")
- sys.props += ("spark.driver.extraJavaOptions" -> "-Dfoo=\"one two three\"")
}
override def afterAll() {
yarnCluster.stop()
- sys.props.retain { case (k, v) => !k.startsWith("spark.") }
- sys.props ++= oldConf
super.afterAll()
}
test("run Spark in yarn-client mode") {
- var result = File.createTempFile("result", null, tempDir)
- YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath()))
- checkResult(result)
-
- // verify log urls are present
- YarnClusterDriver.listener.addedExecutorInfos.values.foreach { info =>
- assert(info.logUrlMap.nonEmpty)
- }
+ testBasicYarnApp(true)
}
test("run Spark in yarn-cluster mode") {
- val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
- var result = File.createTempFile("result", null, tempDir)
-
- val args = Array("--class", main,
- "--jar", "file:" + fakeSparkJar.getAbsolutePath(),
- "--arg", "yarn-cluster",
- "--arg", result.getAbsolutePath(),
- "--num-executors", "1")
- Client.main(args)
- checkResult(result)
-
- // verify log urls are present.
- YarnClusterDriver.listener.addedExecutorInfos.values.foreach { info =>
- assert(info.logUrlMap.nonEmpty)
- }
+ testBasicYarnApp(false)
}
test("run Spark in yarn-cluster mode unsuccessfully") {
- val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
-
- // Use only one argument so the driver will fail
- val args = Array("--class", main,
- "--jar", "file:" + fakeSparkJar.getAbsolutePath(),
- "--arg", "yarn-cluster",
- "--num-executors", "1")
+ // Don't provide arguments so the driver will fail.
val exception = intercept[SparkException] {
- Client.main(args)
+ runSpark(false, mainClassName(YarnClusterDriver.getClass))
+ fail("Spark application should have failed.")
}
- assert(Utils.exceptionString(exception).contains("Application finished with failed status"))
}
test("run Python application in yarn-cluster mode") {
val primaryPyFile = new File(tempDir, "test.py")
- Files.write(TEST_PYFILE, primaryPyFile, Charsets.UTF_8)
+ Files.write(TEST_PYFILE, primaryPyFile, UTF_8)
val pyFile = new File(tempDir, "test2.py")
- Files.write(TEST_PYFILE, pyFile, Charsets.UTF_8)
+ Files.write(TEST_PYFILE, pyFile, UTF_8)
var result = File.createTempFile("result", null, tempDir)
- val args = Array("--class", "org.apache.spark.deploy.PythonRunner",
- "--primary-py-file", primaryPyFile.getAbsolutePath(),
- "--py-files", pyFile.getAbsolutePath(),
- "--arg", "yarn-cluster",
- "--arg", result.getAbsolutePath(),
- "--name", "python test in yarn-cluster mode",
- "--num-executors", "1")
- Client.main(args)
+ // The sbt assembly does not include pyspark / py4j python dependencies, so we need to
+ // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala.
+ val sparkHome = sys.props("spark.test.home")
+ val extraConf = Map(
+ "spark.executorEnv.SPARK_HOME" -> sparkHome,
+ "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome)
+
+ runSpark(false, primaryPyFile.getAbsolutePath(),
+ sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()),
+ appArgs = Seq(result.getAbsolutePath()),
+ extraConf = extraConf)
checkResult(result)
}
+ test("user class path first in client mode") {
+ testUseClassPathFirst(true)
+ }
+
+ test("user class path first in cluster mode") {
+ testUseClassPathFirst(false)
+ }
+
+ private def testBasicYarnApp(clientMode: Boolean): Unit = {
+ var result = File.createTempFile("result", null, tempDir)
+ runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
+ appArgs = Seq(result.getAbsolutePath()))
+ checkResult(result)
+ }
+
+ private def testUseClassPathFirst(clientMode: Boolean): Unit = {
+ // Create a jar file that contains a different version of "test.resource".
+ val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
+ val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir)
+ val driverResult = File.createTempFile("driver", null, tempDir)
+ val executorResult = File.createTempFile("executor", null, tempDir)
+ runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+ appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
+ extraClassPath = Seq(originalJar.getPath()),
+ extraJars = Seq("local:" + userJar.getPath()),
+ extraConf = Map(
+ "spark.driver.userClassPathFirst" -> "true",
+ "spark.executor.userClassPathFirst" -> "true"))
+ checkResult(driverResult, "OVERRIDDEN")
+ checkResult(executorResult, "OVERRIDDEN")
+ }
+
+ private def runSpark(
+ clientMode: Boolean,
+ klass: String,
+ appArgs: Seq[String] = Nil,
+ sparkArgs: Seq[String] = Nil,
+ extraClassPath: Seq[String] = Nil,
+ extraJars: Seq[String] = Nil,
+ extraConf: Map[String, String] = Map()): Unit = {
+ val master = if (clientMode) "yarn-client" else "yarn-cluster"
+ val props = new Properties()
+
+ props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
+
+ val childClasspath = logConfDir.getAbsolutePath() +
+ File.pathSeparator +
+ sys.props("java.class.path") +
+ File.pathSeparator +
+ extraClassPath.mkString(File.pathSeparator)
+ props.setProperty("spark.driver.extraClassPath", childClasspath)
+ props.setProperty("spark.executor.extraClassPath", childClasspath)
+
+ // SPARK-4267: make sure java options are propagated correctly.
+ props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
+ props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
+
+ yarnCluster.getConfig().foreach { e =>
+ props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
+ }
+
+ sys.props.foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
+ props.setProperty(k, v)
+ }
+ }
+
+ extraConf.foreach { case (k, v) => props.setProperty(k, v) }
+
+ val propsFile = File.createTempFile("spark", ".properties", tempDir)
+ val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8)
+ props.store(writer, "Spark properties.")
+ writer.close()
+
+ val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil
+ val mainArgs =
+ if (klass.endsWith(".py")) {
+ Seq(klass)
+ } else {
+ Seq("--class", klass, fakeSparkJar.getAbsolutePath())
+ }
+ val argv =
+ Seq(
+ new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(),
+ "--master", master,
+ "--num-executors", "1",
+ "--properties-file", propsFile.getAbsolutePath()) ++
+ extraJarArgs ++
+ sparkArgs ++
+ mainArgs ++
+ appArgs
+
+ Utils.executeAndGetOutput(argv,
+ extraEnvironment = Map("YARN_CONF_DIR" -> tempDir.getAbsolutePath()))
+ }
+
/**
* This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
* any sort of error when the job process finishes successfully, but the job itself fails. So
* the tests enforce that something is written to a file after everything is ok to indicate
* that the job succeeded.
*/
- private def checkResult(result: File) = {
- var resultString = Files.toString(result, Charsets.UTF_8)
- resultString should be ("success")
+ private def checkResult(result: File): Unit = {
+ checkResult(result, "success")
+ }
+
+ private def checkResult(result: File, expected: String): Unit = {
+ var resultString = Files.toString(result, UTF_8)
+ resultString should be (expected)
+ }
+
+ private def mainClassName(klass: Class[_]): String = {
+ klass.getName().stripSuffix("$")
}
}
@@ -229,22 +294,22 @@ private object YarnClusterDriver extends Logging with Matchers {
val WAIT_TIMEOUT_MILLIS = 10000
var listener: SaveExecutorInfo = null
- def main(args: Array[String]) = {
- if (args.length != 2) {
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
System.err.println(
s"""
|Invalid command line: ${args.mkString(" ")}
|
- |Usage: YarnClusterDriver [master] [result file]
+ |Usage: YarnClusterDriver [result file]
""".stripMargin)
System.exit(1)
}
listener = new SaveExecutorInfo
- val sc = new SparkContext(new SparkConf().setMaster(args(0))
+ val sc = new SparkContext(new SparkConf()
.setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns"))
sc.addSparkListener(listener)
- val status = new File(args(1))
+ val status = new File(args(0))
var result = "failure"
try {
val data = sc.parallelize(1 to 4, 4).collect().toSet
@@ -253,7 +318,48 @@ private object YarnClusterDriver extends Logging with Matchers {
result = "success"
} finally {
sc.stop()
- Files.write(result, status, Charsets.UTF_8)
+ Files.write(result, status, UTF_8)
+ }
+
+ // verify log urls are present
+ listener.addedExecutorInfos.values.foreach { info =>
+ assert(info.logUrlMap.nonEmpty)
+ }
+ }
+
+}
+
+private object YarnClasspathTest {
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: YarnClasspathTest [driver result file] [executor result file]
+ """.stripMargin)
+ System.exit(1)
+ }
+
+ readResource(args(0))
+ val sc = new SparkContext(new SparkConf())
+ try {
+ sc.parallelize(Seq(1)).foreach { x => readResource(args(1)) }
+ } finally {
+ sc.stop()
+ }
+ }
+
+ private def readResource(resultPath: String): Unit = {
+ var result = "failure"
+ try {
+ val ccl = Thread.currentThread().getContextClassLoader()
+ val resource = ccl.getResourceAsStream("test.resource")
+ val bytes = ByteStreams.toByteArray(resource)
+ result = new String(bytes, 0, bytes.length, UTF_8)
+ } finally {
+ Files.write(result, new File(resultPath), UTF_8)
}
}
From a95ed52157473fb0e42e910ee15270e7f0edf943 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Mon, 9 Feb 2015 21:18:48 -0800
Subject: [PATCH 030/817] [SPARK-5703] AllJobsPage throws empty.max exception
If you have a `SparkListenerJobEnd` event without the corresponding `SparkListenerJobStart` event, then `JobProgressListener` will create an empty `JobUIData` with an empty `stageIds` list. However, later in `AllJobsPage` we call `stageIds.max`; on an empty list, `max` throws `java.lang.UnsupportedOperationException: empty.max`.
This crashed my history server.
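For reference, a minimal standalone Scala sketch (not part of the patch; the names are simplified stand-ins) showing why `max` on an empty collection fails and how the `Option`-based guard used in the fix below avoids it:
~~~scala
object EmptyMaxSketch {
  def main(args: Array[String]): Unit = {
    val stageIds: Seq[Int] = Seq.empty

    // Calling stageIds.max directly would throw
    // java.lang.UnsupportedOperationException: empty.max

    // Guarded version, mirroring the pattern used in the fix:
    val lastStageId: Option[Int] =
      Option(stageIds).filter(_.nonEmpty).map(_.max)

    println(lastStageId) // None for an empty list, Some(maxId) otherwise
  }
}
~~~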
Author: Andrew Or
Closes #4490 from andrewor14/jobs-page-max and squashes the following commits:
21797d3 [Andrew Or] Check nonEmpty before calling max
---
.../src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 045c69da06feb..bd923d78a86ce 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -42,7 +42,9 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
}
def makeRow(job: JobUIData): Seq[Node] = {
- val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max)
+ val lastStageInfo = Option(job.stageIds)
+ .filter(_.nonEmpty)
+ .flatMap { ids => listener.stageIdToInfo.get(ids.max) }
val lastStageData = lastStageInfo.flatMap { s =>
listener.stageIdToData.get((s.stageId, s.attemptId))
}
From a2d33d0b01af87e931d9d883638a52d7a86f6248 Mon Sep 17 00:00:00 2001
From: Kay Ousterhout
Date: Mon, 9 Feb 2015 21:22:09 -0800
Subject: [PATCH 031/817] [SPARK-5701] Only set ShuffleReadMetrics when task
has shuffle deps
The updateShuffleReadMetrics method in TaskMetrics (called by the executor heartbeater) will currently always add a ShuffleReadMetrics to TaskMetrics (with values set to 0), even when the task didn't read any shuffle data. ShuffleReadMetrics should only be added if the task reads shuffle data.
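As a hedged illustration (simplified stand-in types, not the real `TaskMetrics` API), the guard amounts to building the merged metrics object only when at least one per-dependency entry exists:
~~~scala
// Simplified stand-in for ShuffleReadMetrics; only two fields are modeled here.
case class ShuffleRead(fetchWaitTime: Long, recordsRead: Long)

object MergeShuffleReadSketch {
  // Returns None when the task had no shuffle dependencies, instead of an all-zero object.
  def merge(deps: Seq[ShuffleRead]): Option[ShuffleRead] =
    if (deps.isEmpty) None
    else Some(deps.reduce { (a, b) =>
      ShuffleRead(a.fetchWaitTime + b.fetchWaitTime, a.recordsRead + b.recordsRead)
    })

  def main(args: Array[String]): Unit = {
    println(merge(Nil))                                        // None
    println(merge(Seq(ShuffleRead(5, 10), ShuffleRead(3, 7)))) // Some(ShuffleRead(8,17))
  }
}
~~~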
Author: Kay Ousterhout
Closes #4488 from kayousterhout/SPARK-5701 and squashes the following commits:
673ed58 [Kay Ousterhout] SPARK-5701: Only set ShuffleReadMetrics when task has shuffle deps
---
.../apache/spark/executor/TaskMetrics.scala | 22 ++++++++-------
.../spark/executor/TaskMetricsSuite.scala | 28 +++++++++++++++++++
2 files changed, 40 insertions(+), 10 deletions(-)
create mode 100644 core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index d05659193b334..bf3f1e4fc7832 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -177,8 +177,8 @@ class TaskMetrics extends Serializable {
* Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed,
* we can store all the different inputMetrics (one per readMethod).
*/
- private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod):
- InputMetrics =synchronized {
+ private[spark] def getInputMetricsForReadMethod(
+ readMethod: DataReadMethod): InputMetrics = synchronized {
_inputMetrics match {
case None =>
val metrics = new InputMetrics(readMethod)
@@ -195,15 +195,17 @@ class TaskMetrics extends Serializable {
* Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
*/
private[spark] def updateShuffleReadMetrics(): Unit = synchronized {
- val merged = new ShuffleReadMetrics()
- for (depMetrics <- depsShuffleReadMetrics) {
- merged.incFetchWaitTime(depMetrics.fetchWaitTime)
- merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
- merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
- merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
- merged.incRecordsRead(depMetrics.recordsRead)
+ if (!depsShuffleReadMetrics.isEmpty) {
+ val merged = new ShuffleReadMetrics()
+ for (depMetrics <- depsShuffleReadMetrics) {
+ merged.incFetchWaitTime(depMetrics.fetchWaitTime)
+ merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
+ merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
+ merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
+ merged.incRecordsRead(depMetrics.recordsRead)
+ }
+ _shuffleReadMetrics = Some(merged)
}
- _shuffleReadMetrics = Some(merged)
}
private[spark] def updateInputMetrics(): Unit = synchronized {
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
new file mode 100644
index 0000000000000..326e203afe136
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.scalatest.FunSuite
+
+class TaskMetricsSuite extends FunSuite {
+ test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") {
+ val taskMetrics = new TaskMetrics()
+ taskMetrics.updateShuffleReadMetrics()
+ assert(taskMetrics.shuffleReadMetrics.isEmpty)
+ }
+}
From bd0b5ea708aa5b84adb67c039ec52408289718bb Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Mon, 9 Feb 2015 21:33:34 -0800
Subject: [PATCH 032/817] [SQL] Remove the duplicated code
Author: Cheng Hao
Closes #4494 from chenghao-intel/tiny_code_change and squashes the following commits:
450dfe7 [Cheng Hao] remove the duplicated code
---
.../org/apache/spark/sql/execution/SparkStrategies.scala | 5 -----
1 file changed, 5 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 81bcf5a6f32dd..edf8a5be64ff1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -342,11 +342,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
ExecutedCommand(
RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil
- case LogicalDescribeCommand(table, isExtended) =>
- val resultPlan = self.sqlContext.executePlan(table).executedPlan
- ExecutedCommand(
- RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil
-
case _ => Nil
}
}
From ef2f55b97f58fa06acb30e9e0172fb66fba383bc Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Mon, 9 Feb 2015 22:09:07 -0800
Subject: [PATCH 033/817] [SPARK-5597][MLLIB] save/load for decision trees and
ensembles
This is based on #4444 from jkbradley with the following changes:
1. Node schema updated to (see the Scala sketch after this list)
~~~
treeId: int
nodeId: Int
predict/
|- predict: Double
|- prob: Double
impurity: Double
isLeaf: Boolean
split/
|- feature: Int
|- threshold: Double
|- featureType: Int
|- categories: Array[Double]
leftNodeId: Integer
rightNodeId: Integer
infoGain: Double
~~~
2. Some refactor of the implementation.
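A hedged Scala sketch of the layout in point 1; it mirrors the `NodeData`/`PredictData`/`SplitData` case classes added to `DecisionTreeModel.scala` further down, and the field types shown here are the ones used there:
~~~scala
// Sketch of the persisted node schema; the Option fields are empty for leaf nodes.
case class PredictData(predict: Double, prob: Double)

case class SplitData(
    feature: Int,
    threshold: Double,
    featureType: Int,
    categories: Seq[Double])

case class NodeData(
    treeId: Int,
    nodeId: Int,
    predict: PredictData,
    impurity: Double,
    isLeaf: Boolean,
    split: Option[SplitData], // only set for internal nodes
    leftNodeId: Option[Int],
    rightNodeId: Option[Int],
    infoGain: Option[Double])
~~~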
Closes #4444.
Author: Joseph K. Bradley
Author: Xiangrui Meng
Closes #4493 from mengxr/SPARK-5597 and squashes the following commits:
75e3bb6 [Xiangrui Meng] fix style
2b0033d [Xiangrui Meng] update tree export schema and refactor the implementation
45873a2 [Joseph K. Bradley] org imports
1d4c264 [Joseph K. Bradley] Added save/load for tree ensembles
dcdbf85 [Joseph K. Bradley] added save/load for decision tree but need to generalize it to ensembles
---
.../mllib/tree/model/DecisionTreeModel.scala | 197 +++++++++++++++++-
.../tree/model/InformationGainStats.scala | 4 +-
.../apache/spark/mllib/tree/model/Node.scala | 5 +
.../spark/mllib/tree/model/Predict.scala | 7 +
.../mllib/tree/model/treeEnsembleModels.scala | 157 +++++++++++++-
.../spark/mllib/tree/DecisionTreeSuite.scala | 120 ++++++++++-
.../tree/GradientBoostedTreesSuite.scala | 81 ++++---
.../spark/mllib/tree/RandomForestSuite.scala | 28 ++-
8 files changed, 561 insertions(+), 38 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a25e625a4017a..89ecf3773dd77 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,11 +17,17 @@
package org.apache.spark.mllib.tree.model
+import scala.collection.mutable
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
/**
* :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
* @param algo algorithm type -- classification or regression
*/
@Experimental
-class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
/**
* Predict values for a single data point using the model trained.
@@ -98,4 +104,193 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
header + topNode.subtreeToString(2)
}
+ override def save(sc: SparkContext, path: String): Unit = {
+ DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object DecisionTreeModel extends Loader[DecisionTreeModel] {
+
+ private[tree] object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+
+ case class PredictData(predict: Double, prob: Double) {
+ def toPredict: Predict = new Predict(predict, prob)
+ }
+
+ object PredictData {
+ def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
+
+ def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1))
+ }
+
+ case class SplitData(
+ feature: Int,
+ threshold: Double,
+ featureType: Int,
+ categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed
+ def toSplit: Split = {
+ new Split(feature, threshold, FeatureType(featureType), categories.toList)
+ }
+ }
+
+ object SplitData {
+ def apply(s: Split): SplitData = {
+ SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
+ }
+
+ def apply(r: Row): SplitData = {
+ SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3))
+ }
+ }
+
+ /** Model data for model import/export */
+ case class NodeData(
+ treeId: Int,
+ nodeId: Int,
+ predict: PredictData,
+ impurity: Double,
+ isLeaf: Boolean,
+ split: Option[SplitData],
+ leftNodeId: Option[Int],
+ rightNodeId: Option[Int],
+ infoGain: Option[Double])
+
+ object NodeData {
+ def apply(treeId: Int, n: Node): NodeData = {
+ NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf,
+ n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id),
+ n.stats.map(_.gain))
+ }
+
+ def apply(r: Row): NodeData = {
+ val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5)))
+ val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6))
+ val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7))
+ val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8))
+ NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3),
+ r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain)
+ }
+ }
+
+ def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadataRDD = sc.parallelize(
+ Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
+ .toDataFrame("class", "version", "algo", "numNodes")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val nodes = model.topNode.subtreeIterator.toSeq
+ val dataRDD: DataFrame = sc.parallelize(nodes)
+ .map(NodeData.apply(0, _))
+ .toDataFrame
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ // Load Parquet data.
+ val dataRDD = sqlContext.parquetFile(datapath)
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[NodeData](dataRDD.schema)
+ val nodes = dataRDD.map(NodeData.apply)
+ // Build node data into a tree.
+ val trees = constructTrees(nodes)
+ assert(trees.size == 1,
+ "Decision tree should contain exactly one tree but got ${trees.size} trees.")
+ val model = new DecisionTreeModel(trees(0), Algo.fromString(algo))
+ assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." +
+ s" Expected $numNodes nodes but found ${model.numNodes}")
+ model
+ }
+
+ def constructTrees(nodes: RDD[NodeData]): Array[Node] = {
+ val trees = nodes
+ .groupBy(_.treeId)
+ .mapValues(_.toArray)
+ .collect()
+ .map { case (treeId, data) =>
+ (treeId, constructTree(data))
+ }.sortBy(_._1)
+ val numTrees = trees.size
+ val treeIndices = trees.map(_._1).toSeq
+ assert(treeIndices == (0 until numTrees),
+ s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")
+ trees.map(_._2)
+ }
+
+ /**
+ * Given a list of nodes from a tree, construct the tree.
+ * @param data array of all node data in a tree.
+ */
+ def constructTree(data: Array[NodeData]): Node = {
+ val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap
+ assert(dataMap.contains(1),
+ s"DecisionTree missing root node (id = 1).")
+ constructNode(1, dataMap, mutable.Map.empty)
+ }
+
+ /**
+ * Builds a node from the node data map and adds new nodes to the input nodes map.
+ */
+ private def constructNode(
+ id: Int,
+ dataMap: Map[Int, NodeData],
+ nodes: mutable.Map[Int, Node]): Node = {
+ if (nodes.contains(id)) {
+ return nodes(id)
+ }
+ val data = dataMap(id)
+ val node =
+ if (data.isLeaf) {
+ Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
+ } else {
+ val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
+ val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
+ val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,
+ rightNode.impurity, leftNode.predict, rightNode.predict)
+ new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
+ data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))
+ }
+ nodes += node.id -> node
+ node
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ val (algo: String, numNodes: Int) = try {
+ val algo_numNodes = metadata.select("algo", "numNodes").collect()
+ assert(algo_numNodes.length == 1)
+ algo_numNodes(0) match {
+ case Row(a: String, n: Int) => (a, n)
+ }
+ } catch {
+ // Catch both Error and Exception since the checks above can throw either.
+ case e: Throwable =>
+ throw new Exception(
+ s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
+ + s" Error message: ${e.getMessage}")
+ }
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ SaveLoadV1_0.load(sc, path, algo, numNodes)
+ case _ => throw new Exception(
+ s"DecisionTreeModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 9a50ecb550c38..80990aa9a603f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -49,7 +49,9 @@ class InformationGainStats(
gain == other.gain &&
impurity == other.impurity &&
leftImpurity == other.leftImpurity &&
- rightImpurity == other.rightImpurity
+ rightImpurity == other.rightImpurity &&
+ leftPredict == other.leftPredict &&
+ rightPredict == other.rightPredict
}
case _ => false
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 2179da8dbe03e..d961081d185e9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -166,6 +166,11 @@ class Node (
}
}
+ /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */
+ private[tree] def subtreeIterator: Iterator[Node] = {
+ Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
+ rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)
+ }
}
private[tree] object Node {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 004838ee5ba0e..ad4c0dbbfb3e5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -32,4 +32,11 @@ class Predict(
override def toString = {
"predict = %f, prob = %f".format(predict, prob)
}
+
+ override def equals(other: Any): Boolean = {
+ other match {
+ case p: Predict => predict == p.predict && prob == p.prob
+ case _ => false
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 22997110de8dd..23bd46baabf65 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -21,12 +21,17 @@ import scala.collection.mutable
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
/**
* :: Experimental ::
@@ -38,9 +43,42 @@ import org.apache.spark.rdd.RDD
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
- combiningStrategy = if (algo == Classification) Vote else Average) {
+ combiningStrategy = if (algo == Classification) Vote else Average)
+ with Saveable {
require(trees.forall(_.algo == algo))
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+ RandomForestModel.SaveLoadV1_0.thisClassName)
+ }
+
+ override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object RandomForestModel extends Loader[RandomForestModel] {
+
+ override def load(sc: SparkContext, path: String): RandomForestModel = {
+ val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+ assert(metadata.treeWeights.forall(_ == 1.0))
+ val trees =
+ TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+ new RandomForestModel(Algo.fromString(metadata.algo), trees)
+ case _ => throw new Exception(s"RandomForestModel.load did not recognize model" +
+ s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private object SaveLoadV1_0 {
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel"
+ }
+
}
/**
@@ -56,9 +94,42 @@ class GradientBoostedTreesModel(
override val algo: Algo,
override val trees: Array[DecisionTreeModel],
override val treeWeights: Array[Double])
- extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+ extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
+ with Saveable {
require(trees.size == treeWeights.size)
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+ GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
+ }
+
+ override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
+
+ override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
+ val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+ assert(metadata.combiningStrategy == Sum.toString)
+ val trees =
+ TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+ new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights)
+ case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" +
+ s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private object SaveLoadV1_0 {
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+ }
+
}
/**
@@ -176,3 +247,85 @@ private[tree] sealed class TreeEnsembleModel(
*/
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
+
+private[tree] object TreeEnsembleModel {
+
+ object SaveLoadV1_0 {
+
+ import DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
+
+ def thisFormatVersion = "1.0"
+
+ case class Metadata(
+ algo: String,
+ treeAlgo: String,
+ combiningStrategy: String,
+ treeWeights: Array[Double])
+
+ /**
+ * Model data for model import/export.
+ * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields
+ * of nested fields; once that is possible, we can use something like:
+ * case class EnsembleNodeData(treeId: Int, node: NodeData),
+ * where NodeData is from DecisionTreeModel.
+ */
+ case class EnsembleNodeData(treeId: Int, node: NodeData)
+
+ def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
+ model.combiningStrategy.toString, model.treeWeights)
+ val metadataRDD = sc.parallelize(Seq((className, thisFormatVersion, metadata)), 1)
+ .toDataFrame("class", "version", "metadata")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
+ tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
+ }.toDataFrame
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ /**
+ * Read metadata from the loaded metadata DataFrame.
+ * @param path Path for loading data, used for debug messages.
+ */
+ def readMetadata(metadata: DataFrame, path: String): Metadata = {
+ try {
+ // We rely on the try-catch for schema checking rather than creating a schema just for this.
+ val metadataArray = metadata.select("metadata.algo", "metadata.treeAlgo",
+ "metadata.combiningStrategy", "metadata.treeWeights").collect()
+ assert(metadataArray.size == 1)
+ Metadata(metadataArray(0).getString(0), metadataArray(0).getString(1),
+ metadataArray(0).getString(2), metadataArray(0).getAs[Seq[Double]](3).toArray)
+ } catch {
+ // Catch both Error and Exception since the checks above can throw either.
+ case e: Throwable =>
+ throw new Exception(
+ s"Unable to load TreeEnsembleModel metadata from: ${Loader.metadataPath(path)}."
+ + s" Error message: ${e.getMessage}")
+ }
+ }
+
+ /**
+ * Load trees for an ensemble, and return them in order.
+ * @param path path to load the model from
+ * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's
+ * algorithm).
+ */
+ def loadTrees(
+ sc: SparkContext,
+ path: String,
+ treeAlgo: String): Array[DecisionTreeModel] = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply)
+ val trees = constructTrees(nodes)
+ trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
+ }
+ }
+
+}
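For reference, a minimal sketch of the save/load round trip that the ensemble code above enables (the path is illustrative, and `sc` and `rfModel` stand for an existing SparkContext and an already-trained model):

  import org.apache.spark.SparkContext
  import org.apache.spark.mllib.tree.model.RandomForestModel

  val sc: SparkContext = ???                // assumed: an existing SparkContext
  val rfModel: RandomForestModel = ???      // assumed: a model trained elsewhere
  // save() writes JSON metadata plus Parquet node data under the given path.
  rfModel.save(sc, "hdfs:///models/rf-v1")
  val sameModel = RandomForestModel.load(sc, "hdfs:///models/rf-v1")

The test suites below follow the same pattern, using checkEqual to verify that the reloaded trees match the originals.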
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 9347eaf9221a8..7b1aed5ffeb3e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -29,8 +29,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
@@ -857,9 +859,32 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}
+
+ test("Node.subtreeIterator") {
+ val model = DecisionTreeSuite.createModel(Classification)
+ val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted
+ assert(nodeIds === DecisionTreeSuite.createdModelNodeIds)
+ }
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = DecisionTreeSuite.createModel(algo)
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = DecisionTreeModel.load(sc, path)
+ DecisionTreeSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
}
-object DecisionTreeSuite {
+object DecisionTreeSuite extends FunSuite {
def validateClassifier(
model: DecisionTreeModel,
@@ -979,4 +1004,95 @@ object DecisionTreeSuite {
arr
}
+ /** Create a leaf node with the given node ID */
+ private def createLeafNode(id: Int): Node = {
+ Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true)
+ }
+
+ /**
+ * Create an internal node with the given node ID and feature type.
+ * Note: This does NOT set the child nodes.
+ */
+ private def createInternalNode(id: Int, featureType: FeatureType): Node = {
+ val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false)
+ featureType match {
+ case Continuous =>
+ node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous,
+ categories = List.empty[Double]))
+ case Categorical =>
+ node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
+ categories = List(0.0, 1.0)))
+ }
+ // TODO: The information gain stats should be consistent with the same info stored in children.
+ node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
+ leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
+ node
+ }
+
+ /**
+ * Create a tree model. This is deterministic and contains a variety of node and feature types.
+ */
+ private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ val topNode = createInternalNode(id = 1, Continuous)
+ val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
+ val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
+ topNode.leftNode = Some(node2)
+ topNode.rightNode = Some(node3)
+ node3.leftNode = Some(node6)
+ node3.rightNode = Some(node7)
+ new DecisionTreeModel(topNode, algo)
+ }
+
+ /** Sorted Node IDs matching the model returned by [[createModel()]] */
+ private val createdModelNodeIds = Array(1, 2, 3, 6, 7)
+
+ /**
+ * Check if the two trees are exactly the same.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ * If the trees are not equal, this throws an exception whose message includes the debug strings
+ * of both trees.
+ */
+ private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ assert(a.algo === b.algo)
+ checkEqual(a.topNode, b.topNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+ * Check that the two nodes and all of their descendants are exactly the same; throws an exception otherwise.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.id === b.id)
+ assert(a.predict === b.predict)
+ assert(a.impurity === b.impurity)
+ assert(a.isLeaf === b.isLeaf)
+ assert(a.split === b.split)
+ (a.stats, b.stats) match {
+ // TODO: Check other fields besides the information gain.
+ case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain)
+ case (None, None) =>
+ case _ => throw new AssertionError(
+ s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})")
+ }
+ (a.leftNode, b.leftNode) match {
+ case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has leftNode defined. " +
+ s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})")
+ }
+ (a.rightNode, b.rightNode) match {
+ case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has rightNode defined. " +
+ s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})")
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index e8341a5d0d104..bde47606eb001 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss}
-
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[GradientBoostedTrees]].
@@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
- GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed =>
- val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
-
- val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
-
- val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
-
- assert(gbt.trees.size === numIterations)
- try {
- EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
- } catch {
- case e: java.lang.AssertionError =>
- println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
- s" subsamplingRate=$subsamplingRate")
- throw e
- }
-
- val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- val dt = DecisionTree.train(remappedInput, treeStrategy)
-
- // Make sure trees are the same.
- assert(gbt.trees.head.toString == dt.toString)
+ val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
+
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+ assert(gbt.trees.size === numIterations)
+ try {
+ EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
+ } catch {
+ case e: java.lang.AssertionError =>
+ println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+ s" subsamplingRate=$subsamplingRate")
+ throw e
}
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ // Make sure trees are the same.
+ assert(gbt.trees.head.toString == dt.toString)
}
}
@@ -133,14 +133,37 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
BoostingStrategy.defaultParams(algo)
}
}
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = new GradientBoostedTreesModel(algo, trees, treeWeights)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = GradientBoostedTreesModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ assert(model.treeWeights === sameModel.treeWeights)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
}
-object GradientBoostedTreesSuite {
+private object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
- val randomSeeds = Array(681283, 4398)
-
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 55e963977b54f..ee3bc98486862 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.Node
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[RandomForest]].
@@ -212,6 +214,26 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
assert(rf1.toDebugString != rf2.toDebugString)
}
-}
-
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray
+ val model = new RandomForestModel(algo, trees)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = RandomForestModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+}
From c15134632e74e3dee05eda20c6ef79915e15d02e Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Mon, 9 Feb 2015 22:45:48 -0800
Subject: [PATCH 034/817] [SPARK-4964][Streaming][Kafka] More updates to
Exactly-once Kafka stream
Changes
- Added example
- Added a critical unit test that verifies that offset ranges can be recovered through checkpoints
Might add more changes.
Author: Tathagata Das
Closes #4384 from tdas/new-kafka-fixes and squashes the following commits:
7c931c3 [Tathagata Das] Small update
3ed9284 [Tathagata Das] updated scala doc
83d0402 [Tathagata Das] Added JavaDirectKafkaWordCount example.
26df23c [Tathagata Das] Updates based on PR comments from Cody
e4abf69 [Tathagata Das] Scala doc improvements and stuff.
bb65232 [Tathagata Das] Fixed test bug and refactored KafkaStreamSuite
50f2b56 [Tathagata Das] Added Java API and added more Scala and Java unit tests. Also updated docs.
e73589c [Tathagata Das] Minor changes.
4986784 [Tathagata Das] Added unit test to kafka offset recovery
6a91cab [Tathagata Das] Added example
---
.../streaming/JavaDirectKafkaWordCount.java | 113 ++++++
.../streaming/DirectKafkaWordCount.scala | 71 ++++
.../kafka/DirectKafkaInputDStream.scala | 5 +-
.../spark/streaming/kafka/KafkaCluster.scala | 3 +
.../spark/streaming/kafka/KafkaRDD.scala | 12 +-
.../streaming/kafka/KafkaRDDPartition.scala | 23 +-
.../spark/streaming/kafka/KafkaUtils.scala | 353 ++++++++++++++----
.../apache/spark/streaming/kafka/Leader.scala | 21 +-
.../spark/streaming/kafka/OffsetRange.scala | 53 ++-
.../kafka/JavaDirectKafkaStreamSuite.java | 159 ++++++++
.../streaming/kafka/JavaKafkaStreamSuite.java | 5 +-
.../kafka/DirectKafkaStreamSuite.scala | 302 +++++++++++++++
.../streaming/kafka/KafkaClusterSuite.scala | 24 +-
.../kafka/KafkaDirectStreamSuite.scala | 92 -----
.../spark/streaming/kafka/KafkaRDDSuite.scala | 8 +-
.../streaming/kafka/KafkaStreamSuite.scala | 62 +--
.../kafka/ReliableKafkaStreamSuite.scala | 4 +-
17 files changed, 1048 insertions(+), 262 deletions(-)
create mode 100644 examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
create mode 100644 examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
create mode 100644 external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
create mode 100644 external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
delete mode 100644 external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaDirectStreamSuite.scala
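For orientation, the checkpoint-based recovery that the new DirectKafkaStreamSuite exercises follows this pattern (a sketch only; the broker address, topic, and checkpoint directory are illustrative):

  import kafka.serializer.StringDecoder
  import org.apache.spark.SparkConf
  import org.apache.spark.streaming.{Seconds, StreamingContext}
  import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils}

  def createContext(): StreamingContext = {
    val ssc = new StreamingContext(new SparkConf().setAppName("DirectKafkaRecovery"), Seconds(2))
    ssc.checkpoint("/tmp/direct-kafka-checkpoint")   // illustrative checkpoint directory
    val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, Map("metadata.broker.list" -> "broker1:9092"), Set("topic1"))
    stream.foreachRDD { rdd =>
      // The offset ranges of each batch travel with the generated RDD.
      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      offsetRanges.foreach(println)
    }
    ssc
  }

  // On restart, the context, the direct stream, and its offsets are rebuilt from the checkpoint.
  val ssc = StreamingContext.getOrCreate("/tmp/direct-kafka-checkpoint", createContext _)
  ssc.start()
  ssc.awaitTermination()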
diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
new file mode 100644
index 0000000000000..bab9f2478e779
--- /dev/null
+++ b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Arrays;
+import java.util.regex.Pattern;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import kafka.serializer.StringDecoder;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.streaming.api.java.*;
+import org.apache.spark.streaming.kafka.KafkaUtils;
+import org.apache.spark.streaming.Durations;
+
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: JavaDirectKafkaWordCount <brokers> <topics>
+ *   <brokers> is a list of one or more Kafka brokers
+ *   <topics> is a list of one or more kafka topics to consume from
+ *
+ * Example:
+ *    $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2
+ */
+
+public final class JavaDirectKafkaWordCount {
+ private static final Pattern SPACE = Pattern.compile(" ");
+
+ public static void main(String[] args) {
+ if (args.length < 2) {
+ System.err.println("Usage: DirectKafkaWordCount \n" +
+ " is a list of one or more Kafka brokers\n" +
+ " is a list of one or more kafka topics to consume from\n\n");
+ System.exit(1);
+ }
+
+ StreamingExamples.setStreamingLogLevels();
+
+ String brokers = args[0];
+ String topics = args[1];
+
+ // Create context with 2 second batch interval
+ SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount");
+ JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2));
+
+ HashSet<String> topicsSet = new HashSet<String>(Arrays.asList(topics.split(",")));
+ HashMap<String, String> kafkaParams = new HashMap<String, String>();
+ kafkaParams.put("metadata.broker.list", brokers);
+
+ // Create direct kafka stream with brokers and topics
+ JavaPairInputDStream<String, String> messages = KafkaUtils.createDirectStream(
+ jssc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ topicsSet
+ );
+
+ // Get the lines, split them into words, count the words and print
+ JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(Tuple2<String, String> tuple2) {
+ return tuple2._2();
+ }
+ });
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(SPACE.split(x));
+ }
+ });
+ JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) {
+ return i1 + i2;
+ }
+ });
+ wordCounts.print();
+
+ // Start the computation
+ jssc.start();
+ jssc.awaitTermination();
+ }
+}
diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
new file mode 100644
index 0000000000000..deb08fd57b8c7
--- /dev/null
+++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming
+
+import kafka.serializer.StringDecoder
+
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.kafka._
+import org.apache.spark.SparkConf
+
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: DirectKafkaWordCount <brokers> <topics>
+ *   <brokers> is a list of one or more Kafka brokers
+ *   <topics> is a list of one or more kafka topics to consume from
+ *
+ * Example:
+ *    $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2
+ */
+object DirectKafkaWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println(s"""
+ |Usage: DirectKafkaWordCount <brokers> <topics>
+ |  <brokers> is a list of one or more Kafka brokers
+ |  <topics> is a list of one or more kafka topics to consume from
+ |
+ """.stripMargin)
+ System.exit(1)
+ }
+
+ StreamingExamples.setStreamingLogLevels()
+
+ val Array(brokers, topics) = args
+
+ // Create context with 2 second batch interval
+ val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount")
+ val ssc = new StreamingContext(sparkConf, Seconds(2))
+
+ // Create direct kafka stream with brokers and topics
+ val topicsSet = topics.split(",").toSet
+ val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
+ val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, topicsSet)
+
+ // Get the lines, split them into words, count the words and print
+ val lines = messages.map(_._2)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _)
+ wordCounts.print()
+
+ // Start the computation
+ ssc.start()
+ ssc.awaitTermination()
+ }
+}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index c7bca43eb889d..04e65cb3d708c 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -50,14 +50,13 @@ import org.apache.spark.streaming.dstream._
* @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive)
* starting point of the stream
* @param messageHandler function for translating each message into the desired type
- * @param maxRetries maximum number of times in a row to retry getting leaders' offsets
*/
private[streaming]
class DirectKafkaInputDStream[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag,
+ U <: Decoder[K]: ClassTag,
+ T <: Decoder[V]: ClassTag,
R: ClassTag](
@transient ssc_ : StreamingContext,
val kafkaParams: Map[String, String],
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index ccc62bfe8f057..2f7e0ab39fefd 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -332,6 +332,9 @@ object KafkaCluster {
extends ConsumerConfig(originalProps) {
val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp =>
val hpa = hp.split(":")
+ if (hpa.size == 1) {
+ throw new SparkException(s"Broker not in the correct format of <host>:<port> [$brokers]")
+ }
(hpa(0), hpa(1).toInt)
}
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index 50bf7cbdb8dbf..d56cc01be9514 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -36,14 +36,12 @@ import kafka.utils.VerifiableProperties
* Starting and ending offsets are specified in advance,
* so that you can control exactly-once semantics.
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
- * @param batch Each KafkaRDDPartition in the batch corresponds to a
- * range of offsets for a given Kafka topic/partition
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD
* @param messageHandler function for translating each message into the desired type
*/
-private[spark]
+private[kafka]
class KafkaRDD[
K: ClassTag,
V: ClassTag,
@@ -183,7 +181,7 @@ class KafkaRDD[
}
}
-private[spark]
+private[kafka]
object KafkaRDD {
import KafkaCluster.LeaderOffset
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
index 36372e08f65f6..a842a6f17766f 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Partition
* @param host preferred kafka host, i.e. the leader at the time the rdd was created
* @param port preferred kafka host's port
*/
-private[spark]
+private[kafka]
class KafkaRDDPartition(
val index: Int,
val topic: String,
@@ -36,24 +36,3 @@ class KafkaRDDPartition(
val host: String,
val port: Int
) extends Partition
-
-private[spark]
-object KafkaRDDPartition {
- def apply(
- index: Int,
- topic: String,
- partition: Int,
- fromOffset: Long,
- untilOffset: Long,
- host: String,
- port: Int
- ): KafkaRDDPartition = new KafkaRDDPartition(
- index,
- topic,
- partition,
- fromOffset,
- untilOffset,
- host,
- port
- )
-}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index f8aa6c5c6263c..7a2c3abdcc24b 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -18,7 +18,9 @@
package org.apache.spark.streaming.kafka
import java.lang.{Integer => JInt}
+import java.lang.{Long => JLong}
import java.util.{Map => JMap}
+import java.util.{Set => JSet}
import scala.reflect.ClassTag
import scala.collection.JavaConversions._
@@ -27,18 +29,19 @@ import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.{Decoder, StringDecoder}
-
+import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext}
+import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
+import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
object KafkaUtils {
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param ssc StreamingContext object
* @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..)
* @param groupId The group id for this consumer
@@ -62,7 +65,7 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param ssc StreamingContext object
* @param kafkaParams Map of kafka configuration parameters,
* see http://kafka.apache.org/08/configuration.html
@@ -81,7 +84,7 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
* @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..)
@@ -99,7 +102,7 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param jssc JavaStreamingContext object
* @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
@@ -119,10 +122,10 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param jssc JavaStreamingContext object
- * @param keyTypeClass Key type of RDD
- * @param valueTypeClass value type of RDD
+ * @param keyTypeClass Key type of DStream
+ * @param valueTypeClass value type of DStream
* @param keyDecoderClass Type of kafka key decoder
* @param valueDecoderClass Type of kafka value decoder
* @param kafkaParams Map of kafka configuration parameters,
@@ -151,14 +154,14 @@ object KafkaUtils {
jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel)
}
- /** A batch-oriented interface for consuming from Kafka.
- * Starting and ending offsets are specified in advance,
- * so that you can control exactly-once semantics.
+ /**
+ * Create a RDD from Kafka using offset ranges for each topic and partition.
+ *
* @param sc SparkContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
* @param offsetRanges Each OffsetRange in the batch corresponds to a
* range of offsets for a given Kafka topic/partition
*/
@@ -166,12 +169,12 @@ object KafkaUtils {
def createRDD[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag] (
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag](
sc: SparkContext,
kafkaParams: Map[String, String],
offsetRanges: Array[OffsetRange]
- ): RDD[(K, V)] = {
+ ): RDD[(K, V)] = {
val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
val kc = new KafkaCluster(kafkaParams)
val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet
@@ -179,121 +182,196 @@ object KafkaUtils {
errs => throw new SparkException(errs.mkString("\n")),
ok => ok
)
- new KafkaRDD[K, V, U, T, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler)
+ new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler)
}
- /** A batch-oriented interface for consuming from Kafka.
- * Starting and ending offsets are specified in advance,
- * so that you can control exactly-once semantics.
+ /**
+ * :: Experimental ::
+ * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you
+ * to specify the Kafka leader to connect to (to optimize fetching) and access the message as well
+ * as the metadata.
+ *
* @param sc SparkContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
* @param offsetRanges Each OffsetRange in the batch corresponds to a
* range of offsets for a given Kafka topic/partition
* @param leaders Kafka leaders for each offset range in batch
- * @param messageHandler function for translating each message into the desired type
+ * @param messageHandler Function for translating each message and metadata into the desired type
*/
@Experimental
def createRDD[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag,
- R: ClassTag] (
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag,
+ R: ClassTag](
sc: SparkContext,
kafkaParams: Map[String, String],
offsetRanges: Array[OffsetRange],
leaders: Array[Leader],
messageHandler: MessageAndMetadata[K, V] => R
- ): RDD[R] = {
-
+ ): RDD[R] = {
val leaderMap = leaders
.map(l => TopicAndPartition(l.topic, l.partition) -> (l.host, l.port))
.toMap
- new KafkaRDD[K, V, U, T, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler)
+ new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler)
}
+
/**
- * This stream can guarantee that each message from Kafka is included in transformations
- * (as opposed to output actions) exactly once, even in most failure situations.
+ * Create a RDD from Kafka using offset ranges for each topic and partition.
*
- * Points to note:
- *
- * Failure Recovery - You must checkpoint this stream, or save offsets yourself and provide them
- * as the fromOffsets parameter on restart.
- * Kafka must have sufficient log retention to obtain messages after failure.
- *
- * Getting offsets from the stream - see programming guide
+ * @param jsc JavaSparkContext object
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
+ * @param offsetRanges Each OffsetRange in the batch corresponds to a
+ * range of offsets for a given Kafka topic/partition
+ */
+ @Experimental
+ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]](
+ jsc: JavaSparkContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ kafkaParams: JMap[String, String],
+ offsetRanges: Array[OffsetRange]
+ ): JavaPairRDD[K, V] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ new JavaPairRDD(createRDD[K, V, KD, VD](
+ jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges))
+ }
+
+ /**
+ * :: Experimental ::
+ * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you
+ * to specify the Kafka leader to connect to (to optimize fetching) and access the message as well
+ * as the metadata.
*
-. * Zookeeper - This does not use Zookeeper to store offsets. For interop with Kafka monitors
- * that depend on Zookeeper, you must store offsets in ZK yourself.
+ * @param jsc JavaSparkContext object
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
+ * @param offsetRanges Each OffsetRange in the batch corresponds to a
+ * range of offsets for a given Kafka topic/partition
+ * @param leaders Kafka leaders for each offset range in batch
+ * @param messageHandler Function for translating each message and metadata into the desired type
+ */
+ @Experimental
+ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
+ jsc: JavaSparkContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ recordClass: Class[R],
+ kafkaParams: JMap[String, String],
+ offsetRanges: Array[OffsetRange],
+ leaders: Array[Leader],
+ messageHandler: JFunction[MessageAndMetadata[K, V], R]
+ ): JavaRDD[R] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ implicit val recordCmt: ClassTag[R] = ClassTag(recordClass)
+ createRDD[K, V, KD, VD, R](
+ jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaders, messageHandler.call _)
+ }
+
+ /**
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
*
- * End-to-end semantics - This does not guarantee that any output operation will push each record
- * exactly once. To ensure end-to-end exactly-once semantics (that is, receiving exactly once and
- * outputting exactly once), you have to either ensure that the output operation is
- * idempotent, or transactionally store offsets with the output. See the programming guide for
- * more details.
+ * Points to note:
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that every record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
*
* @param ssc StreamingContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
- * @param messageHandler function for translating each message into the desired type
- * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive)
- * starting point of the stream
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
+ * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive)
+ * starting point of the stream
+ * @param messageHandler Function for translating each message and metadata into the desired type
*/
@Experimental
def createDirectStream[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag,
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag,
R: ClassTag] (
ssc: StreamingContext,
kafkaParams: Map[String, String],
fromOffsets: Map[TopicAndPartition, Long],
messageHandler: MessageAndMetadata[K, V] => R
): InputDStream[R] = {
- new DirectKafkaInputDStream[K, V, U, T, R](
+ new DirectKafkaInputDStream[K, V, KD, VD, R](
ssc, kafkaParams, fromOffsets, messageHandler)
}
/**
- * This stream can guarantee that each message from Kafka is included in transformations
- * (as opposed to output actions) exactly once, even in most failure situations.
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
*
* Points to note:
- *
- * Failure Recovery - You must checkpoint this stream.
- * Kafka must have sufficient log retention to obtain messages after failure.
- *
- * Getting offsets from the stream - see programming guide
- *
-. * Zookeeper - This does not use Zookeeper to store offsets. For interop with Kafka monitors
- * that depend on Zookeeper, you must store offsets in ZK yourself.
- *
- * End-to-end semantics - This does not guarantee that any output operation will push each record
- * exactly once. To ensure end-to-end exactly-once semantics (that is, receiving exactly once and
- * outputting exactly once), you have to ensure that the output operation is idempotent.
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that every record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
*
* @param ssc StreamingContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
- * If starting without a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers), specified in
+ * host1:port1,host2:port2 form.
+ * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
* to determine where the stream starts (defaults to "largest")
- * @param topics names of the topics to consume
+ * @param topics Names of the topics to consume
*/
@Experimental
def createDirectStream[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag] (
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag] (
ssc: StreamingContext,
kafkaParams: Map[String, String],
topics: Set[String]
@@ -313,11 +391,128 @@ object KafkaUtils {
val fromOffsets = leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
- new DirectKafkaInputDStream[K, V, U, T, (K, V)](
+ new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}).fold(
errs => throw new SparkException(errs.mkString("\n")),
ok => ok
)
}
+
+ /**
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
+ *
+ * Points to note:
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that every record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
+ *
+ * @param jssc JavaStreamingContext object
+ * @param keyClass Class of the keys in the Kafka records
+ * @param valueClass Class of the values in the Kafka records
+ * @param keyDecoderClass Class of the key decoder
+ * @param valueDecoderClass Class of the value decoder
+ * @param recordClass Class of the records in DStream
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers), specified in
+ * host1:port1,host2:port2 form.
+ * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive)
+ * starting point of the stream
+ * @param messageHandler Function for translating each message and metadata into the desired type
+ */
+ @Experimental
+ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
+ jssc: JavaStreamingContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ recordClass: Class[R],
+ kafkaParams: JMap[String, String],
+ fromOffsets: JMap[TopicAndPartition, JLong],
+ messageHandler: JFunction[MessageAndMetadata[K, V], R]
+ ): JavaInputDStream[R] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ implicit val recordCmt: ClassTag[R] = ClassTag(recordClass)
+ createDirectStream[K, V, KD, VD, R](
+ jssc.ssc,
+ Map(kafkaParams.toSeq: _*),
+ Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*),
+ messageHandler.call _
+ )
+ }
+
+ /**
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
+ *
+ * Points to note:
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that every record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
+ *
+ * @param jssc JavaStreamingContext object
+ * @param keyClass Class of the keys in the Kafka records
+ * @param valueClass Class of the values in the Kafka records
+ * @param keyDecoderClass Class of the key decoder
+ * @param valueDecoderClass Class of the value decoder
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers), specified in
+ * host1:port1,host2:port2 form.
+ * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
+ * to determine where the stream starts (defaults to "largest")
+ * @param topics Names of the topics to consume
+ */
+ @Experimental
+ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
+ jssc: JavaStreamingContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ kafkaParams: JMap[String, String],
+ topics: JSet[String]
+ ): JavaPairInputDStream[K, V] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ createDirectStream[K, V, KD, VD](
+ jssc.ssc,
+ Map(kafkaParams.toSeq: _*),
+ Set(topics.toSeq: _*)
+ )
+ }
}
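For reference, a minimal sketch of the batch-oriented createRDD API documented above (the broker address, topic, and offsets are illustrative; `sc` stands for an existing SparkContext):

  import kafka.serializer.StringDecoder
  import org.apache.spark.SparkContext
  import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

  val sc: SparkContext = ???    // assumed: an existing SparkContext
  val kafkaParams = Map("metadata.broker.list" -> "broker1:9092")
  val offsetRanges = Array(
    OffsetRange.create("topic1", partition = 0, fromOffset = 0L, untilOffset = 100L))

  // Returns an RDD of (key, value) pairs covering exactly the requested offsets.
  val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
    sc, kafkaParams, offsetRanges)
  println(rdd.count())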
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Leader.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Leader.scala
index 3454d92e72b47..c129a26836c0d 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Leader.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Leader.scala
@@ -19,17 +19,28 @@ package org.apache.spark.streaming.kafka
import kafka.common.TopicAndPartition
-/** Host info for the leader of a Kafka TopicAndPartition */
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Represent the host info for the leader of a Kafka partition.
+ */
+@Experimental
final class Leader private(
- /** kafka topic name */
+ /** Kafka topic name */
val topic: String,
- /** kafka partition id */
+ /** Kafka partition id */
val partition: Int,
- /** kafka hostname */
+ /** Leader's hostname */
val host: String,
- /** kafka host's port */
+ /** Leader's port */
val port: Int) extends Serializable
+/**
+ * :: Experimental ::
+ * Companion object that provides methods to create instances of [[Leader]].
+ */
+@Experimental
object Leader {
def create(topic: String, partition: Int, host: String, port: Int): Leader =
new Leader(topic, partition, host, port)
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
index 334c12e4627b4..9c3dfeb8f5928 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
@@ -19,16 +19,35 @@ package org.apache.spark.streaming.kafka
import kafka.common.TopicAndPartition
-/** Something that has a collection of OffsetRanges */
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the
+ * offset ranges in RDDs generated by the direct Kafka DStream (see
+ * [[KafkaUtils.createDirectStream()]]).
+ * {{{
+ * KafkaUtils.createDirectStream(...).foreachRDD { rdd =>
+ * val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ * ...
+ * }
+ * }}}
+ */
+@Experimental
trait HasOffsetRanges {
def offsetRanges: Array[OffsetRange]
}
-/** Represents a range of offsets from a single Kafka TopicAndPartition */
+/**
+ * :: Experimental ::
+ * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class
+ * can be created with `OffsetRange.create()`.
+ */
+@Experimental
final class OffsetRange private(
- /** kafka topic name */
+ /** Kafka topic name */
val topic: String,
- /** kafka partition id */
+ /** Kafka partition id */
val partition: Int,
/** inclusive starting offset */
val fromOffset: Long,
@@ -36,11 +55,33 @@ final class OffsetRange private(
val untilOffset: Long) extends Serializable {
import OffsetRange.OffsetRangeTuple
+ override def equals(obj: Any): Boolean = obj match {
+ case that: OffsetRange =>
+ this.topic == that.topic &&
+ this.partition == that.partition &&
+ this.fromOffset == that.fromOffset &&
+ this.untilOffset == that.untilOffset
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ toTuple.hashCode()
+ }
+
+ override def toString(): String = {
+ s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]"
+ }
+
/** this is to avoid ClassNotFoundException during checkpoint restore */
private[streaming]
def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset)
}
+/**
+ * :: Experimental ::
+ * Companion object that provides methods to create instances of [[OffsetRange]].
+ */
+@Experimental
object OffsetRange {
def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange =
new OffsetRange(topic, partition, fromOffset, untilOffset)
@@ -61,10 +102,10 @@ object OffsetRange {
new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset)
/** this is to avoid ClassNotFoundException during checkpoint restore */
- private[spark]
+ private[kafka]
type OffsetRangeTuple = (String, Int, Long, Long)
- private[streaming]
+ private[kafka]
def apply(t: OffsetRangeTuple) =
new OffsetRange(t._1, t._2, t._3, t._4)
}
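
For illustration (not part of this patch): the equals/hashCode added above make offset ranges comparable by value, which is what checkpoint recovery depends on. A short sketch with made-up offsets.

    // Illustrative values only.
    val a = OffsetRange.create("pageviews", 0, 100L, 200L)
    val b = OffsetRange.create("pageviews", 0, 100L, 200L)
    assert(a == b && a.hashCode == b.hashCode)    // value equality, not reference equality
    assert(a.untilOffset - a.fromOffset == 100L)  // number of offsets covered by the range
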
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
new file mode 100644
index 0000000000000..1334cc8fd1b57
--- /dev/null
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Arrays;
+
+import org.apache.spark.SparkConf;
+
+import scala.Tuple2;
+
+import junit.framework.Assert;
+
+import kafka.common.TopicAndPartition;
+import kafka.message.MessageAndMetadata;
+import kafka.serializer.StringDecoder;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+import org.junit.Test;
+import org.junit.After;
+import org.junit.Before;
+
+public class JavaDirectKafkaStreamSuite implements Serializable {
+ private transient JavaStreamingContext ssc = null;
+ private transient Random random = new Random();
+ private transient KafkaStreamSuiteBase suiteBase = null;
+
+ @Before
+ public void setUp() {
+ suiteBase = new KafkaStreamSuiteBase() { };
+ suiteBase.setupKafka();
+ System.clearProperty("spark.driver.port");
+ SparkConf sparkConf = new SparkConf()
+ .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+ ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200));
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ System.clearProperty("spark.driver.port");
+ suiteBase.tearDownKafka();
+ }
+
+ @Test
+ public void testKafkaStream() throws InterruptedException {
+ String topic1 = "topic1";
+ String topic2 = "topic2";
+
+ String[] topic1data = createTopicAndSendData(topic1);
+ String[] topic2data = createTopicAndSendData(topic2);
+
+ HashSet<String> sent = new HashSet<String>();
+ sent.addAll(Arrays.asList(topic1data));
+ sent.addAll(Arrays.asList(topic2data));
+
+ HashMap<String, String> kafkaParams = new HashMap<String, String>();
+ kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress());
+ kafkaParams.put("auto.offset.reset", "smallest");
+
+ JavaDStream<String> stream1 = KafkaUtils.createDirectStream(
+ ssc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ topicToSet(topic1)
+ ).map(
+ new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(scala.Tuple2<String, String> kv) throws Exception {
+ return kv._2();
+ }
+ }
+ );
+
+ JavaDStream<String> stream2 = KafkaUtils.createDirectStream(
+ ssc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ String.class,
+ kafkaParams,
+ topicOffsetToMap(topic2, (long) 0),
+ new Function<MessageAndMetadata<String, String>, String>() {
+ @Override
+ public String call(MessageAndMetadata<String, String> msgAndMd) throws Exception {
+ return msgAndMd.message();
+ }
+ }
+ );
+ JavaDStream<String> unifiedStream = stream1.union(stream2);
+
+ final HashSet<String> result = new HashSet<String>();
+ unifiedStream.foreachRDD(
+ new Function<JavaRDD<String>, Void>() {
+ @Override
+ public Void call(org.apache.spark.api.java.JavaRDD<String> rdd) throws Exception {
+ result.addAll(rdd.collect());
+ return null;
+ }
+ }
+ );
+ ssc.start();
+ long startTime = System.currentTimeMillis();
+ boolean matches = false;
+ while (!matches && System.currentTimeMillis() - startTime < 20000) {
+ matches = sent.size() == result.size();
+ Thread.sleep(50);
+ }
+ Assert.assertEquals(sent, result);
+ ssc.stop();
+ }
+
+ private HashSet<String> topicToSet(String topic) {
+ HashSet<String> topicSet = new HashSet<String>();
+ topicSet.add(topic);
+ return topicSet;
+ }
+
+ private HashMap<TopicAndPartition, Long> topicOffsetToMap(String topic, Long offsetToStart) {
+ HashMap<TopicAndPartition, Long> topicMap = new HashMap<TopicAndPartition, Long>();
+ topicMap.put(new TopicAndPartition(topic, 0), offsetToStart);
+ return topicMap;
+ }
+
+ private String[] createTopicAndSendData(String topic) {
+ String[] data = { topic + "-1", topic + "-2", topic + "-3"};
+ suiteBase.createTopic(topic);
+ suiteBase.sendMessages(topic, data);
+ return data;
+ }
+}
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index 6e1abf3f385ee..208cc51b29876 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -79,9 +79,10 @@ public void testKafkaStream() throws InterruptedException {
suiteBase.createTopic(topic);
HashMap<String, Object> tmp = new HashMap<String, Object>(sent);
- suiteBase.produceAndSendMessage(topic,
+ suiteBase.sendMessages(topic,
JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
- Predef.<Tuple2<String, Object>>conforms()));
+ Predef.<Tuple2<String, Object>>conforms())
+ );
HashMap<String, String> kafkaParams = new HashMap<String, String>();
kafkaParams.put("zookeeper.connect", suiteBase.zkAddress());
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
new file mode 100644
index 0000000000000..b25c2120d54f7
--- /dev/null
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -0,0 +1,302 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import kafka.serializer.StringDecoder
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.concurrent.{Eventually, Timeouts}
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
+import org.apache.spark.streaming.dstream.{DStream, InputDStream}
+import org.apache.spark.util.Utils
+import kafka.common.TopicAndPartition
+import kafka.message.MessageAndMetadata
+
+class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
+ with BeforeAndAfter with BeforeAndAfterAll with Eventually {
+ val sparkConf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName(this.getClass.getSimpleName)
+
+ var sc: SparkContext = _
+ var ssc: StreamingContext = _
+ var testDir: File = _
+
+ override def beforeAll {
+ setupKafka()
+ }
+
+ override def afterAll {
+ tearDownKafka()
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ sc = null
+ }
+ if (sc != null) {
+ sc.stop()
+ }
+ if (testDir != null) {
+ Utils.deleteRecursively(testDir)
+ }
+ }
+
+
+ test("basic stream receiving with multiple topics and smallest starting offset") {
+ val topics = Set("basic1", "basic2", "basic3")
+ val data = Map("a" -> 7, "b" -> 9)
+ topics.foreach { t =>
+ createTopic(t)
+ sendMessages(t, data)
+ }
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "smallest"
+ )
+
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, topics)
+ }
+ var total = 0L
+
+ stream.foreachRDD { rdd =>
+ // Get the offset ranges in the RDD
+ val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ val collected = rdd.mapPartitionsWithIndex { (i, iter) =>
+ // For each partition, get size of the range in the partition,
+ // and the number of items in the partition
+ val off = offsets(i)
+ val all = iter.toSeq
+ val partSize = all.size
+ val rangeSize = off.untilOffset - off.fromOffset
+ Iterator((partSize, rangeSize))
+ }.collect
+
+ // Verify whether number of elements in each partition
+ // matches with the corresponding offset range
+ collected.foreach { case (partSize, rangeSize) =>
+ assert(partSize === rangeSize, "offset ranges are wrong")
+ }
+ total += collected.size // Add up all the collected items
+ }
+ ssc.start()
+ eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+ assert(total === data.values.sum * topics.size, "didn't get all messages")
+ }
+ ssc.stop()
+ }
+
+ test("receiving from largest starting offset") {
+ val topic = "largest"
+ val topicPartition = TopicAndPartition(topic, 0)
+ val data = Map("a" -> 10)
+ createTopic(topic)
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "largest"
+ )
+ val kc = new KafkaCluster(kafkaParams)
+ def getLatestOffset(): Long = {
+ kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset
+ }
+
+ // Send some initial messages before starting context
+ sendMessages(topic, data)
+ eventually(timeout(10 seconds), interval(20 milliseconds)) {
+ assert(getLatestOffset() > 3)
+ }
+ val offsetBeforeStart = getLatestOffset()
+
+ // Setup context and kafka stream with largest offset
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, Set(topic))
+ }
+ assert(
+ stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]]
+ .fromOffsets(topicPartition) >= offsetBeforeStart,
+ "Start offset not from latest"
+ )
+
+ val collectedData = new mutable.ArrayBuffer[String]()
+ stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() }
+ ssc.start()
+ val newData = Map("b" -> 10)
+ sendMessages(topic, newData)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ collectedData.contains("b")
+ }
+ assert(!collectedData.contains("a"))
+ }
+
+
+ test("creating stream by offset") {
+ val topic = "offset"
+ val topicPartition = TopicAndPartition(topic, 0)
+ val data = Map("a" -> 10)
+ createTopic(topic)
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "largest"
+ )
+ val kc = new KafkaCluster(kafkaParams)
+ def getLatestOffset(): Long = {
+ kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset
+ }
+
+ // Send some initial messages before starting context
+ sendMessages(topic, data)
+ eventually(timeout(10 seconds), interval(20 milliseconds)) {
+ assert(getLatestOffset() >= 10)
+ }
+ val offsetBeforeStart = getLatestOffset()
+
+ // Setup context and kafka stream with largest offset
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder, String](
+ ssc, kafkaParams, Map(topicPartition -> 11L),
+ (m: MessageAndMetadata[String, String]) => m.message())
+ }
+ assert(
+ stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]]
+ .fromOffsets(topicPartition) >= offsetBeforeStart,
+ "Start offset not from latest"
+ )
+
+ val collectedData = new mutable.ArrayBuffer[String]()
+ stream.foreachRDD { rdd => collectedData ++= rdd.collect() }
+ ssc.start()
+ val newData = Map("b" -> 10)
+ sendMessages(topic, newData)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ collectedData.contains("b")
+ }
+ assert(!collectedData.contains("a"))
+ }
+
+ // Test to verify the offset ranges can be recovered from the checkpoints
+ test("offset recovery") {
+ val topic = "recovery"
+ createTopic(topic)
+ testDir = Utils.createTempDir()
+
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "smallest"
+ )
+
+ // Send data to Kafka and wait for it to be received
+ def sendDataAndWaitForReceive(data: Seq[Int]) {
+ val strings = data.map { _.toString}
+ sendMessages(topic, strings.map { _ -> 1}.toMap)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains })
+ }
+ }
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(100))
+ val kafkaStream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, Set(topic))
+ }
+ val keyedStream = kafkaStream.map { v => "key" -> v._2.toInt }
+ val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) =>
+ Some(values.sum + state.getOrElse(0))
+ }
+ ssc.checkpoint(testDir.getAbsolutePath)
+
+ // This is to collect the raw data received from Kafka
+ kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) =>
+ val data = rdd.map { _._2 }.collect()
+ DirectKafkaStreamSuite.collectedData.appendAll(data)
+ }
+
+ // This is to ensure all the data is eventually received only once
+ stateStream.foreachRDD { (rdd: RDD[(String, Int)]) =>
+ rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 }
+ }
+ ssc.start()
+
+ // Send some data and wait for them to be received
+ for (i <- (1 to 10).grouped(4)) {
+ sendDataAndWaitForReceive(i)
+ }
+
+ // Verify that offset ranges were generated
+ val offsetRangesBeforeStop = getOffsetRanges(kafkaStream)
+ assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated")
+ assert(
+ offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 },
+ "starting offset not zero"
+ )
+ ssc.stop()
+ logInfo("====== RESTARTING ========")
+
+ // Recover context from checkpoints
+ ssc = new StreamingContext(testDir.getAbsolutePath)
+ val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]]
+
+ // Verify offset ranges have been recovered
+ val recoveredOffsetRanges = getOffsetRanges(recoveredStream)
+ assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered")
+ val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) }
+ assert(
+ recoveredOffsetRanges.forall { or =>
+ earlierOffsetRangesAsSets.contains((or._1, or._2.toSet))
+ },
+ "Recovered ranges are not the same as the ones generated"
+ )
+
+ // Restart context, give more data and verify the total at the end
+ // If the total is right, that means each record has been received only once
+ ssc.start()
+ sendDataAndWaitForReceive(11 to 20)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(DirectKafkaStreamSuite.total === (1 to 20).sum)
+ }
+ ssc.stop()
+ }
+
+ /** Get the generated offset ranges from the DirectKafkaStream */
+ private def getOffsetRanges[K, V](
+ kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
+ kafkaStream.generatedRDDs.mapValues { rdd =>
+ rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges
+ }.toSeq.sortBy { _._1 }
+ }
+}
+
+object DirectKafkaStreamSuite {
+ val collectedData = new mutable.ArrayBuffer[String]()
+ var total = -1L
+}
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
index e57c8f6987fdc..fc9275b7207be 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
@@ -19,33 +19,29 @@ package org.apache.spark.streaming.kafka
import scala.util.Random
-import org.scalatest.BeforeAndAfter
import kafka.common.TopicAndPartition
+import org.scalatest.BeforeAndAfterAll
-class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
- val brokerHost = "localhost"
-
- val kafkaParams = Map("metadata.broker.list" -> s"$brokerHost:$brokerPort")
-
- val kc = new KafkaCluster(kafkaParams)
-
+class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
val topic = "kcsuitetopic" + Random.nextInt(10000)
-
val topicAndPartition = TopicAndPartition(topic, 0)
+ var kc: KafkaCluster = null
- before {
+ override def beforeAll() {
setupKafka()
createTopic(topic)
- produceAndSendMessage(topic, Map("a" -> 1))
+ sendMessages(topic, Map("a" -> 1))
+ kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress"))
}
- after {
+ override def afterAll() {
tearDownKafka()
}
test("metadata apis") {
- val leader = kc.findLeaders(Set(topicAndPartition)).right.get
- assert(leader(topicAndPartition) === (brokerHost, brokerPort), "didn't get leader")
+ val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition)
+ val leaderAddress = s"${leader._1}:${leader._2}"
+ assert(leaderAddress === brokerAddress, "didn't get leader")
val parts = kc.getPartitions(Set(topic)).right.get
assert(parts(topicAndPartition), "didn't get partitions")
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaDirectStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaDirectStreamSuite.scala
deleted file mode 100644
index 0891ce344f16a..0000000000000
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaDirectStreamSuite.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kafka
-
-import scala.util.Random
-import scala.concurrent.duration._
-
-import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.Eventually
-
-import kafka.serializer.StringDecoder
-
-import org.apache.spark.SparkConf
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-
-class KafkaDirectStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually {
- val sparkConf = new SparkConf()
- .setMaster("local[4]")
- .setAppName(this.getClass.getSimpleName)
-
- val brokerHost = "localhost"
-
- val kafkaParams = Map(
- "metadata.broker.list" -> s"$brokerHost:$brokerPort",
- "auto.offset.reset" -> "smallest"
- )
-
- var ssc: StreamingContext = _
-
- before {
- setupKafka()
-
- ssc = new StreamingContext(sparkConf, Milliseconds(500))
- }
-
- after {
- if (ssc != null) {
- ssc.stop()
- }
- tearDownKafka()
- }
-
- test("multi topic stream") {
- val topics = Set("newA", "newB")
- val data = Map("a" -> 7, "b" -> 9)
- topics.foreach { t =>
- createTopic(t)
- produceAndSendMessage(t, data)
- }
- val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
- ssc, kafkaParams, topics)
- var total = 0L;
-
- stream.foreachRDD { rdd =>
- val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
- val collected = rdd.mapPartitionsWithIndex { (i, iter) =>
- val off = offsets(i)
- val all = iter.toSeq
- val partSize = all.size
- val rangeSize = off.untilOffset - off.fromOffset
- all.map { _ =>
- (partSize, rangeSize)
- }.toIterator
- }.collect
- collected.foreach { case (partSize, rangeSize) =>
- assert(partSize === rangeSize, "offset ranges are wrong")
- }
- total += collected.size
- }
- ssc.start()
- eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
- assert(total === data.values.sum * topics.size, "didn't get all messages")
- }
- ssc.stop()
- }
-}
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 9b9e3f5fce8bd..6774db854a0d0 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -46,9 +46,9 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
val topic = "topic1"
val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
createTopic(topic)
- produceAndSendMessage(topic, sent)
+ sendMessages(topic, sent)
- val kafkaParams = Map("metadata.broker.list" -> s"localhost:$brokerPort",
+ val kafkaParams = Map("metadata.broker.list" -> brokerAddress,
"group.id" -> s"test-consumer-${Random.nextInt(10000)}")
val kc = new KafkaCluster(kafkaParams)
@@ -65,14 +65,14 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
val rdd2 = getRdd(kc, Set(topic))
val sent2 = Map("d" -> 1)
- produceAndSendMessage(topic, sent2)
+ sendMessages(topic, sent2)
// this is the "0 messages" case
// make sure we dont get anything, since messages were sent after rdd was defined
assert(rdd2.isDefined)
assert(rdd2.get.count === 0)
val rdd3 = getRdd(kc, Set(topic))
- produceAndSendMessage(topic, Map("extra" -> 22))
+ sendMessages(topic, Map("extra" -> 22))
// this is the "exactly 1 message" case
// make sure we get exactly one message, despite there being lots more available
assert(rdd3.isDefined)
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index f207dc6d4fa04..e4966eebb9b34 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -48,30 +48,41 @@ import org.apache.spark.util.Utils
*/
abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging {
- var zkAddress: String = _
- var zkClient: ZkClient = _
-
private val zkHost = "localhost"
+ private var zkPort: Int = 0
private val zkConnectionTimeout = 6000
private val zkSessionTimeout = 6000
private var zookeeper: EmbeddedZookeeper = _
- private var zkPort: Int = 0
- protected var brokerPort = 9092
+ private val brokerHost = "localhost"
+ private var brokerPort = 9092
private var brokerConf: KafkaConfig = _
private var server: KafkaServer = _
private var producer: Producer[String, String] = _
+ private var zkReady = false
+ private var brokerReady = false
+
+ protected var zkClient: ZkClient = _
+
+ def zkAddress: String = {
+ assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address")
+ s"$zkHost:$zkPort"
+ }
+
+ def brokerAddress: String = {
+ assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address")
+ s"$brokerHost:$brokerPort"
+ }
def setupKafka() {
// Zookeeper server startup
zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
// Get the actual zookeeper binding port
zkPort = zookeeper.actualPort
- zkAddress = s"$zkHost:$zkPort"
- logInfo("==================== 0 ====================")
+ zkReady = true
+ logInfo("==================== Zookeeper Started ====================")
- zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout,
- ZKStringSerializer)
- logInfo("==================== 1 ====================")
+ zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer)
+ logInfo("==================== Zookeeper Client Created ====================")
// Kafka broker startup
var bindSuccess: Boolean = false
@@ -80,9 +91,8 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin
val brokerProps = getBrokerConfig()
brokerConf = new KafkaConfig(brokerProps)
server = new KafkaServer(brokerConf)
- logInfo("==================== 2 ====================")
server.startup()
- logInfo("==================== 3 ====================")
+ logInfo("==================== Kafka Broker Started ====================")
bindSuccess = true
} catch {
case e: KafkaException =>
@@ -94,10 +104,13 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin
}
Thread.sleep(2000)
- logInfo("==================== 4 ====================")
+ logInfo("==================== Kafka + Zookeeper Ready ====================")
+ brokerReady = true
}
def tearDownKafka() {
+ brokerReady = false
+ zkReady = false
if (producer != null) {
producer.close()
producer = null
@@ -121,26 +134,23 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin
}
}
- private def createTestMessage(topic: String, sent: Map[String, Int])
- : Seq[KeyedMessage[String, String]] = {
- val messages = for ((s, freq) <- sent; i <- 0 until freq) yield {
- new KeyedMessage[String, String](topic, s)
- }
- messages.toSeq
- }
-
def createTopic(topic: String) {
AdminUtils.createTopic(zkClient, topic, 1, 1)
- logInfo("==================== 5 ====================")
// wait until metadata is propagated
waitUntilMetadataIsPropagated(topic, 0)
+ logInfo(s"==================== Topic $topic Created ====================")
}
- def produceAndSendMessage(topic: String, sent: Map[String, Int]) {
+ def sendMessages(topic: String, messageToFreq: Map[String, Int]) {
+ val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray
+ sendMessages(topic, messages)
+ }
+
+ def sendMessages(topic: String, messages: Array[String]) {
producer = new Producer[String, String](new ProducerConfig(getProducerConfig()))
- producer.send(createTestMessage(topic, sent): _*)
+ producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*)
producer.close()
- logInfo("==================== 6 ====================")
+ logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================")
}
private def getBrokerConfig(): Properties = {
@@ -218,7 +228,7 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
val topic = "topic1"
val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
createTopic(topic)
- produceAndSendMessage(topic, sent)
+ sendMessages(topic, sent)
val kafkaParams = Map("zookeeper.connect" -> zkAddress,
"group.id" -> s"test-consumer-${Random.nextInt(10000)}",
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index 64ccc92c81fa9..fc53c23abda85 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -79,7 +79,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
test("Reliable Kafka input stream with single topic") {
var topic = "test-topic"
createTopic(topic)
- produceAndSendMessage(topic, data)
+ sendMessages(topic, data)
// Verify whether the offset of this group/topic/partition is 0 before starting.
assert(getCommitOffset(groupId, topic, 0) === None)
@@ -111,7 +111,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1)
topics.foreach { case (t, _) =>
createTopic(t)
- produceAndSendMessage(t, data)
+ sendMessages(t, data)
}
// Before started, verify all the group/topic/partition offsets are 0.
From 2d1e916730492f5d61b97da6c483d3223ca44315 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Tue, 10 Feb 2015 09:19:01 +0000
Subject: [PATCH 035/817] SPARK-5239 [CORE] JdbcRDD throws
"java.lang.AbstractMethodError: oracle.jdbc.driver.xxxxxx.isClosed()Z"
This is a completion of https://github.com/apache/spark/pull/4033 which was withdrawn for some reason.
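For context (not part of the patch): the change simply drops the JDBC4-only isClosed() probe. A hedged sketch of the resulting cleanup pattern, using a hypothetical helper name; per the JDBC javadoc, close() on an already-closed resource is a no-op, so a null check plus a catch-all is enough.

    // Hypothetical helper, not Spark code; mirrors the pattern in the close() methods below.
    def closeQuietly(conn: java.sql.Connection): Unit = {
      try {
        if (conn != null) {
          conn.close()  // safe even if the connection was already closed
        }
      } catch {
        case e: Exception => ()  // log and continue; cleanup must never fail the task
      }
    }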
Author: Sean Owen
Closes #4470 from srowen/SPARK-5239.2 and squashes the following commits:
2398bde [Sean Owen] Avoid use of JDBC4-only isClosed()
---
core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala | 6 +++---
.../src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 6 +++---
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 642a12c1edf6c..4fe7622bda00f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -99,21 +99,21 @@ class JdbcRDD[T: ClassTag](
override def close() {
try {
- if (null != rs && ! rs.isClosed()) {
+ if (null != rs) {
rs.close()
}
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
try {
- if (null != stmt && ! stmt.isClosed()) {
+ if (null != stmt) {
stmt.close()
}
} catch {
case e: Exception => logWarning("Exception closing statement", e)
}
try {
- if (null != conn && ! conn.isClosed()) {
+ if (null != conn) {
conn.close()
}
logInfo("closed connection")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 0bec32cca1325..87304ce2496b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -370,21 +370,21 @@ private[sql] class JDBCRDD(
def close() {
if (closed) return
try {
- if (null != rs && ! rs.isClosed()) {
+ if (null != rs) {
rs.close()
}
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
try {
- if (null != stmt && ! stmt.isClosed()) {
+ if (null != stmt) {
stmt.close()
}
} catch {
case e: Exception => logWarning("Exception closing statement", e)
}
try {
- if (null != conn && ! conn.isClosed()) {
+ if (null != conn) {
conn.close()
}
logInfo("closed connection")
From ba667935f8670293f10b8bbe1e317b28d00f9875 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Tue, 10 Feb 2015 02:28:47 -0800
Subject: [PATCH 036/817] [SPARK-5700] [SQL] [Build] Bumps jets3t to 0.9.3 for
hadoop-2.3 and hadoop-2.4 profiles
This is a follow-up PR for #4454 and #4484. JetS3t 0.9.2 contains a log4j.properties file inside the artifact and breaks our tests (see SPARK-5696). This is fixed in 0.9.3.
This PR also reverts hotfix changes introduced in #4484. The reason is that asking users to configure HiveThriftServer2 logging configurations in hive-log4j.properties can be unintuitive.
(Reviewable: https://reviewable.io/reviews/apache/spark/4499)
Author: Cheng Lian
Closes #4499 from liancheng/spark-5700 and squashes the following commits:
4f020c7 [Cheng Lian] Bumps jets3t to 0.9.3 for hadoop-2.3 and hadoop-2.4 profiles
---
pom.xml | 4 ++--
.../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 3 ---
2 files changed, 2 insertions(+), 5 deletions(-)
diff --git a/pom.xml b/pom.xml
index a9e968af25453..56e37d42265c0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1578,7 +1578,7 @@
<hadoop.version>2.3.0</hadoop.version>
<protobuf.version>2.5.0</protobuf.version>
- <jets3t.version>0.9.2</jets3t.version>
+ <jets3t.version>0.9.3</jets3t.version>
<hbase.version>0.98.7-hadoop2</hbase.version>
<commons.math3.version>3.1.1</commons.math3.version>
<avro.mapred.classifier>hadoop2</avro.mapred.classifier>
@@ -1591,7 +1591,7 @@
<hadoop.version>2.4.0</hadoop.version>
<protobuf.version>2.5.0</protobuf.version>
- <jets3t.version>0.9.2</jets3t.version>
+ <jets3t.version>0.9.3</jets3t.version>
<hbase.version>0.98.7-hadoop2</hbase.version>
<commons.math3.version>3.1.1</commons.math3.version>
<avro.mapred.classifier>hadoop2</avro.mapred.classifier>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 525777aa454c4..6e07df18b0e15 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.hive.thriftserver
import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.hive.common.LogUtils
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService}
@@ -55,8 +54,6 @@ object HiveThriftServer2 extends Logging {
System.exit(-1)
}
- LogUtils.initHiveLog4j()
-
logInfo("Starting SparkContext")
SparkSQLEnv.init()
From 50820f15277187e8522520a3ecae412abbdb4397 Mon Sep 17 00:00:00 2001
From: Nicholas Chammas
Date: Tue, 10 Feb 2015 15:45:38 +0000
Subject: [PATCH 037/817] [SPARK-1805] [EC2] Validate instance types
Addresses [SPARK-1805](https://issues.apache.org/jira/browse/SPARK-1805), though doesn't resolve it completely.
Error out quickly if the user asks for the master and slaves to have different AMI virtualization types, since we don't currently support that.
In addition to that, we print warnings if the inputted instance types are not recognized, though I would prefer if we errored out. Elsewhere in the script it seems [we allow unrecognized instance types](https://github.com/apache/spark/blob/5de14cc2763a8211f77eeb55940dec025822eb78/ec2/spark_ec2.py#L331), though I think we should remove that.
It's messy, but it should serve us until we enhance spark-ec2 to support clusters with mixed virtualization types.
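For illustration (not part of this patch, and in Scala rather than the script's Python): the rule the script now enforces is two dictionary lookups plus a comparison, sketched here with a made-up subset of the instance-type table.

    // Tiny made-up subset of the EC2_INSTANCE_TYPES table.
    val virtType = Map("m1.large" -> "pvm", "m3.large" -> "hvm")

    def checkVirtualization(masterType: String, slaveType: String): Unit =
      (virtType.get(masterType), virtType.get(slaveType)) match {
        case (Some(m), Some(s)) if m != s =>
          sys.error(s"master ($m) and slave ($s) AMI virtualization types must match")
        case (None, _) | (_, None) =>
          println("Warning: unrecognized instance type; continuing anyway")
        case _ => ()  // both resolve to the same virtualization type
      }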
Author: Nicholas Chammas
Closes #4455 from nchammas/ec2-master-slave-different-virtualization and squashes the following commits:
ce28609 [Nicholas Chammas] fix style
b0adba0 [Nicholas Chammas] validate input instance types
---
ec2/spark_ec2.py | 132 +++++++++++++++++++++++++++++------------------
1 file changed, 81 insertions(+), 51 deletions(-)
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 3e4c49c0e1db6..fe510f12bcec6 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -112,6 +112,7 @@ def parse_args():
version="%prog {v}".format(v=SPARK_EC2_VERSION),
usage="%prog [options] \n\n"
+ " can be: launch, destroy, login, stop, start, get-master, reboot-slaves")
+
parser.add_option(
"-s", "--slaves", type="int", default=1,
help="Number of slaves to launch (default: %default)")
@@ -139,7 +140,9 @@ def parse_args():
help="Availability zone to launch instances in, or 'all' to spread " +
"slaves across multiple (an additional $0.01/Gb for bandwidth" +
"between zones applies) (default: a single zone chosen at random)")
- parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use")
+ parser.add_option(
+ "-a", "--ami",
+ help="Amazon Machine Image ID to use")
parser.add_option(
"-v", "--spark-version", default=DEFAULT_SPARK_VERSION,
help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)")
@@ -179,10 +182,11 @@ def parse_args():
"Only possible on EBS-backed AMIs. " +
"EBS volumes are only attached if --ebs-vol-size > 0." +
"Only support up to 8 EBS volumes.")
- parser.add_option("--placement-group", type="string", default=None,
- help="Which placement group to try and launch " +
- "instances into. Assumes placement group is already " +
- "created.")
+ parser.add_option(
+ "--placement-group", type="string", default=None,
+ help="Which placement group to try and launch " +
+ "instances into. Assumes placement group is already " +
+ "created.")
parser.add_option(
"--swap", metavar="SWAP", type="int", default=1024,
help="Swap space to set up per node, in MB (default: %default)")
@@ -226,9 +230,11 @@ def parse_args():
"--copy-aws-credentials", action="store_true", default=False,
help="Add AWS credentials to hadoop configuration to allow Spark to access S3")
parser.add_option(
- "--subnet-id", default=None, help="VPC subnet to launch instances in")
+ "--subnet-id", default=None,
+ help="VPC subnet to launch instances in")
parser.add_option(
- "--vpc-id", default=None, help="VPC to launch instances in")
+ "--vpc-id", default=None,
+ help="VPC to launch instances in")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -290,52 +296,54 @@ def is_active(instance):
return (instance.state in ['pending', 'running', 'stopping', 'stopped'])
-# Attempt to resolve an appropriate AMI given the architecture and region of the request.
# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/
# Last Updated: 2014-06-20
# For easy maintainability, please keep this manually-inputted dictionary sorted by key.
+EC2_INSTANCE_TYPES = {
+ "c1.medium": "pvm",
+ "c1.xlarge": "pvm",
+ "c3.2xlarge": "pvm",
+ "c3.4xlarge": "pvm",
+ "c3.8xlarge": "pvm",
+ "c3.large": "pvm",
+ "c3.xlarge": "pvm",
+ "cc1.4xlarge": "hvm",
+ "cc2.8xlarge": "hvm",
+ "cg1.4xlarge": "hvm",
+ "cr1.8xlarge": "hvm",
+ "hi1.4xlarge": "pvm",
+ "hs1.8xlarge": "pvm",
+ "i2.2xlarge": "hvm",
+ "i2.4xlarge": "hvm",
+ "i2.8xlarge": "hvm",
+ "i2.xlarge": "hvm",
+ "m1.large": "pvm",
+ "m1.medium": "pvm",
+ "m1.small": "pvm",
+ "m1.xlarge": "pvm",
+ "m2.2xlarge": "pvm",
+ "m2.4xlarge": "pvm",
+ "m2.xlarge": "pvm",
+ "m3.2xlarge": "hvm",
+ "m3.large": "hvm",
+ "m3.medium": "hvm",
+ "m3.xlarge": "hvm",
+ "r3.2xlarge": "hvm",
+ "r3.4xlarge": "hvm",
+ "r3.8xlarge": "hvm",
+ "r3.large": "hvm",
+ "r3.xlarge": "hvm",
+ "t1.micro": "pvm",
+ "t2.medium": "hvm",
+ "t2.micro": "hvm",
+ "t2.small": "hvm",
+}
+
+
+# Attempt to resolve an appropriate AMI given the architecture and region of the request.
def get_spark_ami(opts):
- instance_types = {
- "c1.medium": "pvm",
- "c1.xlarge": "pvm",
- "c3.2xlarge": "pvm",
- "c3.4xlarge": "pvm",
- "c3.8xlarge": "pvm",
- "c3.large": "pvm",
- "c3.xlarge": "pvm",
- "cc1.4xlarge": "hvm",
- "cc2.8xlarge": "hvm",
- "cg1.4xlarge": "hvm",
- "cr1.8xlarge": "hvm",
- "hi1.4xlarge": "pvm",
- "hs1.8xlarge": "pvm",
- "i2.2xlarge": "hvm",
- "i2.4xlarge": "hvm",
- "i2.8xlarge": "hvm",
- "i2.xlarge": "hvm",
- "m1.large": "pvm",
- "m1.medium": "pvm",
- "m1.small": "pvm",
- "m1.xlarge": "pvm",
- "m2.2xlarge": "pvm",
- "m2.4xlarge": "pvm",
- "m2.xlarge": "pvm",
- "m3.2xlarge": "hvm",
- "m3.large": "hvm",
- "m3.medium": "hvm",
- "m3.xlarge": "hvm",
- "r3.2xlarge": "hvm",
- "r3.4xlarge": "hvm",
- "r3.8xlarge": "hvm",
- "r3.large": "hvm",
- "r3.xlarge": "hvm",
- "t1.micro": "pvm",
- "t2.medium": "hvm",
- "t2.micro": "hvm",
- "t2.small": "hvm",
- }
- if opts.instance_type in instance_types:
- instance_type = instance_types[opts.instance_type]
+ if opts.instance_type in EC2_INSTANCE_TYPES:
+ instance_type = EC2_INSTANCE_TYPES[opts.instance_type]
else:
instance_type = "pvm"
print >> stderr,\
@@ -605,8 +613,6 @@ def launch_cluster(conn, opts, cluster_name):
# Get the EC2 instances in an existing cluster if available.
# Returns a tuple of lists of EC2 instance objects for the masters and slaves
-
-
def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
print "Searching for existing cluster " + cluster_name + "..."
reservations = conn.get_all_reservations()
@@ -1050,6 +1056,30 @@ def real_main():
print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file)
sys.exit(1)
+ if opts.instance_type not in EC2_INSTANCE_TYPES:
+ print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
+ t=opts.instance_type)
+
+ if opts.master_instance_type != "":
+ if opts.master_instance_type not in EC2_INSTANCE_TYPES:
+ print >> stderr, \
+ "Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
+ t=opts.master_instance_type)
+ # Since we try instance types even if we can't resolve them, we check if they resolve first
+ # and, if they do, see if they resolve to the same virtualization type.
+ if opts.instance_type in EC2_INSTANCE_TYPES and \
+ opts.master_instance_type in EC2_INSTANCE_TYPES:
+ if EC2_INSTANCE_TYPES[opts.instance_type] != \
+ EC2_INSTANCE_TYPES[opts.master_instance_type]:
+ print >> stderr, \
+ "Error: spark-ec2 currently does not support having a master and slaves with " + \
+ "different AMI virtualization types."
+ print >> stderr, "master instance virtualization type: {t}".format(
+ t=EC2_INSTANCE_TYPES[opts.master_instance_type])
+ print >> stderr, "slave instance virtualization type: {t}".format(
+ t=EC2_INSTANCE_TYPES[opts.instance_type])
+ sys.exit(1)
+
if opts.ebs_vol_num > 8:
print >> stderr, "ebs-vol-num cannot be greater than 8"
sys.exit(1)
From 6cc96cf0c3ea87ab65d42a59725959d94701577b Mon Sep 17 00:00:00 2001
From: JqueryFan
Date: Tue, 10 Feb 2015 17:37:32 +0000
Subject: [PATCH 038/817] [Spark-5717] [MLlib] add stop and reorganize import
Trivial. add sc stop and reorganize import
https://issues.apache.org/jira/browse/SPARK-5717
Author: JqueryFan
Author: Yuhao Yang
Closes #4503 from hhbyyh/scstop and squashes the following commits:
7837a2c [JqueryFan] revert import change
2e85cc1 [Yuhao Yang] add stop and reorganize import
---
.../java/org/apache/spark/examples/mllib/JavaLDAExample.java | 1 +
.../main/scala/org/apache/spark/examples/mllib/LDAExample.scala | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
index f394ff2084463..36207ae38d9a9 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
@@ -71,5 +71,6 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
}
System.out.println();
}
+ sc.stop();
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 0e1b27a8bd2ee..11399a7633638 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -159,7 +159,7 @@ object LDAExample {
}
println()
}
-
+ sc.stop()
}
/**
From c7ad80ae4256c88e380e7488d48cf6eb14a92d76 Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Tue, 10 Feb 2015 11:08:21 -0800
Subject: [PATCH 039/817] [SPARK-5716] [SQL] Support TOK_CHARSETLITERAL in
HiveQl
Author: Daoyuan Wang
Closes #4502 from adrian-wang/utf8 and squashes the following commits:
4d7b0ee [Daoyuan Wang] remove useless import
606f981 [Daoyuan Wang] support TOK_CHARSETLITERAL in HiveQl
---
.../spark/sql/hive/execution/HiveCompatibilitySuite.scala | 1 +
.../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 4 ++++
.../golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 | 0
.../golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 | 0
.../golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b | 1 +
.../golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce | 1 +
.../golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 | 1 +
7 files changed, 8 insertions(+)
create mode 100644 sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7
create mode 100644 sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75
create mode 100644 sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b
create mode 100644 sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce
create mode 100644 sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index a6266f611c219..e443e5bd5f54d 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -518,6 +518,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"inputddl2",
"inputddl3",
"inputddl4",
+ "inputddl5",
"inputddl6",
"inputddl7",
"inputddl8",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f51af62d3340b..969868aef2917 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive
import java.sql.Date
+
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.conf.HiveConf
@@ -1237,6 +1238,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL =>
Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1)))
+ case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL =>
+ Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText))
+
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} :
diff --git a/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 b/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 b/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b
new file mode 100644
index 0000000000000..518a70918b2c7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b
@@ -0,0 +1 @@
+name string
diff --git a/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce
new file mode 100644
index 0000000000000..33398360345d7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce
@@ -0,0 +1 @@
+邵铮
diff --git a/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783
@@ -0,0 +1 @@
+1
From 69bc3bb6cffe82aee5ecd0b09410a847ba486b15 Mon Sep 17 00:00:00 2001
From: Sandy Ryza
Date: Tue, 10 Feb 2015 11:07:25 -0800
Subject: [PATCH 040/817] SPARK-4136. Under dynamic allocation, cancel
outstanding executor requests when no longer needed
This takes advantage of the changes made in SPARK-4337 to cancel pending requests to YARN when they are no longer needed.
Each time the timer in `ExecutorAllocationManager` strikes, we compute `maxNumNeededExecutors`, the maximum number of executors we could fill with the current load. This is calculated as the total number of running and pending tasks divided by the number of cores per executor. If `maxNumNeededExecutors` is below the total number of running and pending executors, we call `requestTotalExecutors(maxNumNeededExecutors)` to let the cluster manager know that it should cancel any pending requests above this amount. If not, `maxNumNeededExecutors` is just used as a bound alongside the configured `maxExecutors` to limit the number of new requests.
The patch modifies the API exposed by `ExecutorAllocationClient` for requesting additional executors by moving from `requestExecutors` to `requestTotalExecutors`. This makes the communication between the `ExecutorAllocationManager` and the `YarnAllocator` easier to reason about and removes some state that needed to be kept in the `CoarseGrainedSchedulerBackend`. I think an argument can be made that this makes for a less attractive user-facing API in `SparkContext`, but I'm having trouble envisioning situations where a user would want to use either of these APIs.
This will likely break some tests, but I wanted to get feedback on the approach before adding tests and polishing.
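For illustration (not part of the PR): the sizing rule described above is an integer round-up, sketched here with illustrative names and numbers.

    // Executors needed to hold every running and pending task, rounded up.
    def maxNumNeededExecutors(running: Int, pending: Int, tasksPerExecutor: Int): Int =
      (running + pending + tasksPerExecutor - 1) / tasksPerExecutor

    // e.g. 10 running + 5 pending tasks at 4 tasks per executor => ceil(15 / 4) = 4
    assert(maxNumNeededExecutors(10, 5, 4) == 4)
    // If 4 is below the current running + pending executor total, the manager calls
    // requestTotalExecutors(4) so the cluster manager can cancel the surplus requests.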
Author: Sandy Ryza
Closes #4168 from sryza/sandy-spark-4136 and squashes the following commits:
37ce77d [Sandy Ryza] Warn on negative number
cd3b2ff [Sandy Ryza] SPARK-4136
---
.../spark/ExecutorAllocationClient.scala | 8 +
.../spark/ExecutorAllocationManager.scala | 149 ++++++++++++------
.../scala/org/apache/spark/SparkContext.scala | 21 ++-
.../CoarseGrainedSchedulerBackend.scala | 20 ++-
.../ExecutorAllocationManagerSuite.scala | 36 ++++-
5 files changed, 184 insertions(+), 50 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index a46a81eabd965..079055e00c6c3 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -19,9 +19,17 @@ package org.apache.spark
/**
* A client that communicates with the cluster manager to request or kill executors.
+ * This is currently supported only in YARN mode.
*/
private[spark] trait ExecutorAllocationClient {
+ /**
+ * Express a preference to the cluster manager for a given total number of executors.
+ * This can result in canceling pending requests or filing additional requests.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ private[spark] def requestTotalExecutors(numExecutors: Int): Boolean
+
/**
* Request an additional number of executors from the cluster manager.
* Return whether the request is acknowledged by the cluster manager.
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 02d54bf3b53cc..998695b6ac8ab 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -201,18 +201,34 @@ private[spark] class ExecutorAllocationManager(
}
/**
- * If the add time has expired, request new executors and refresh the add time.
- * If the remove time for an existing executor has expired, kill the executor.
+ * The number of executors we would have if the cluster manager were to fulfill all our existing
+ * requests.
+ */
+ private def targetNumExecutors(): Int =
+ numExecutorsPending + executorIds.size - executorsPendingToRemove.size
+
+ /**
+ * The maximum number of executors we would need under the current load to satisfy all running
+ * and pending tasks, rounded up.
+ */
+ private def maxNumExecutorsNeeded(): Int = {
+ val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks
+ (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor
+ }
+
+ /**
+ * This is called at a fixed interval to regulate the number of pending executor requests
+ * and number of executors running.
+ *
+ * First, adjust our requested executors based on the add time and our current needs.
+ * Then, if the remove time for an existing executor has expired, kill the executor.
+ *
* This is factored out into its own method for testing.
*/
private def schedule(): Unit = synchronized {
val now = clock.getTimeMillis
- if (addTime != NOT_SET && now >= addTime) {
- addExecutors()
- logDebug(s"Starting timer to add more executors (to " +
- s"expire in $sustainedSchedulerBacklogTimeout seconds)")
- addTime += sustainedSchedulerBacklogTimeout * 1000
- }
+
+ addOrCancelExecutorRequests(now)
removeTimes.retain { case (executorId, expireTime) =>
val expired = now >= expireTime
@@ -223,59 +239,89 @@ private[spark] class ExecutorAllocationManager(
}
}
+ /**
+ * Check to see whether our existing allocation and the requests we've made previously exceed our
+ * current needs. If so, let the cluster manager know so that it can cancel pending requests that
+ * are unneeded.
+ *
+ * If not, and the add time has expired, see if we can request new executors and refresh the add
+ * time.
+ *
+ * @return the delta in the target number of executors.
+ */
+ private def addOrCancelExecutorRequests(now: Long): Int = synchronized {
+ val currentTarget = targetNumExecutors
+ val maxNeeded = maxNumExecutorsNeeded
+
+ if (maxNeeded < currentTarget) {
+ // The target number exceeds the number we actually need, so stop adding new
+ // executors and inform the cluster manager to cancel the extra pending requests.
+ val newTotalExecutors = math.max(maxNeeded, minNumExecutors)
+ client.requestTotalExecutors(newTotalExecutors)
+ numExecutorsToAdd = 1
+ updateNumExecutorsPending(newTotalExecutors)
+ } else if (addTime != NOT_SET && now >= addTime) {
+ val delta = addExecutors(maxNeeded)
+ logDebug(s"Starting timer to add more executors (to " +
+ s"expire in $sustainedSchedulerBacklogTimeout seconds)")
+ addTime += sustainedSchedulerBacklogTimeout * 1000
+ delta
+ } else {
+ 0
+ }
+ }
+
/**
* Request a number of executors from the cluster manager.
* If the cap on the number of executors is reached, give up and reset the
* number of executors to add next round instead of continuing to double it.
- * Return the number actually requested.
+ *
+ * @param maxNumExecutorsNeeded the maximum number of executors all currently running or pending
+ * tasks could fill
+ * @return the number of additional executors actually requested.
*/
- private def addExecutors(): Int = synchronized {
- // Do not request more executors if we have already reached the upper bound
- val numExistingExecutors = executorIds.size + numExecutorsPending
- if (numExistingExecutors >= maxNumExecutors) {
+ private def addExecutors(maxNumExecutorsNeeded: Int): Int = {
+ // Do not request more executors if it would put our target over the upper bound
+ val currentTarget = targetNumExecutors
+ if (currentTarget >= maxNumExecutors) {
logDebug(s"Not adding executors because there are already ${executorIds.size} " +
s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)")
numExecutorsToAdd = 1
return 0
}
- // The number of executors needed to satisfy all pending tasks is the number of tasks pending
- // divided by the number of tasks each executor can fit, rounded up.
- val maxNumExecutorsPending =
- (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor
- if (numExecutorsPending >= maxNumExecutorsPending) {
- logDebug(s"Not adding executors because there are already $numExecutorsPending " +
- s"pending and pending tasks could only fill $maxNumExecutorsPending")
- numExecutorsToAdd = 1
- return 0
- }
-
- // It's never useful to request more executors than could satisfy all the pending tasks, so
- // cap request at that amount.
- // Also cap request with respect to the configured upper bound.
- val maxNumExecutorsToAdd = math.min(
- maxNumExecutorsPending - numExecutorsPending,
- maxNumExecutors - numExistingExecutors)
- assert(maxNumExecutorsToAdd > 0)
-
- val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd)
-
- val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
- val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd)
+ val actualMaxNumExecutors = math.min(maxNumExecutors, maxNumExecutorsNeeded)
+ val newTotalExecutors = math.min(currentTarget + numExecutorsToAdd, actualMaxNumExecutors)
+ val addRequestAcknowledged = testing || client.requestTotalExecutors(newTotalExecutors)
if (addRequestAcknowledged) {
- logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " +
- s"tasks are backlogged (new desired total will be $newTotalExecutors)")
- numExecutorsToAdd =
- if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1
- numExecutorsPending += actualNumExecutorsToAdd
- actualNumExecutorsToAdd
+ val delta = updateNumExecutorsPending(newTotalExecutors)
+ logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" +
+ s" (new desired total will be $newTotalExecutors)")
+ numExecutorsToAdd = if (delta == numExecutorsToAdd) {
+ numExecutorsToAdd * 2
+ } else {
+ 1
+ }
+ delta
} else {
- logWarning(s"Unable to reach the cluster manager " +
- s"to request $actualNumExecutorsToAdd executors!")
+ logWarning(
+ s"Unable to reach the cluster manager to request $newTotalExecutors total executors!")
0
}
}
+ /**
+ * Given the new target number of executors, update the number of pending executor requests,
+ * and return the delta from the old number of pending requests.
+ */
+ private def updateNumExecutorsPending(newTotalExecutors: Int): Int = {
+ val newNumExecutorsPending =
+ newTotalExecutors - executorIds.size + executorsPendingToRemove.size
+ val delta = newNumExecutorsPending - numExecutorsPending
+ numExecutorsPending = newNumExecutorsPending
+ delta
+ }
+
/**
* Request the cluster manager to remove the given executor.
* Return whether the request is received.
@@ -415,6 +461,8 @@ private[spark] class ExecutorAllocationManager(
private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
+ // Number of tasks currently running on the cluster. Should be 0 when no stages are active.
+ private var numRunningTasks: Int = _
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
val stageId = stageSubmitted.stageInfo.stageId
@@ -435,6 +483,10 @@ private[spark] class ExecutorAllocationManager(
// This is needed in case the stage is aborted for any reason
if (stageIdToNumTasks.isEmpty) {
allocationManager.onSchedulerQueueEmpty()
+ if (numRunningTasks != 0) {
+ logWarning("No stages are running, but numRunningTasks != 0")
+ numRunningTasks = 0
+ }
}
}
}
@@ -446,6 +498,7 @@ private[spark] class ExecutorAllocationManager(
val executorId = taskStart.taskInfo.executorId
allocationManager.synchronized {
+ numRunningTasks += 1
// This guards against the race condition in which the `SparkListenerTaskStart`
// event is posted before the `SparkListenerBlockManagerAdded` event, which is
// possible because these events are posted in different threads. (see SPARK-4951)
@@ -475,7 +528,8 @@ private[spark] class ExecutorAllocationManager(
val executorId = taskEnd.taskInfo.executorId
val taskId = taskEnd.taskInfo.taskId
allocationManager.synchronized {
- // If the executor is no longer running scheduled any tasks, mark it as idle
+ numRunningTasks -= 1
+ // If the executor is no longer running any scheduled tasks, mark it as idle
if (executorIdToTaskIds.contains(executorId)) {
executorIdToTaskIds(executorId) -= taskId
if (executorIdToTaskIds(executorId).isEmpty) {
@@ -514,6 +568,11 @@ private[spark] class ExecutorAllocationManager(
}.sum
}
+ /**
+ * The number of tasks currently running across all stages.
+ */
+ def totalRunningTasks(): Int = numRunningTasks
+
/**
* Return true if an executor is not currently running a task, and false otherwise.
*
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8d3c3d000adf3..04ca5d1019e4b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1103,10 +1103,27 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
listenerBus.addListener(listener)
}
+ /**
+ * Express a preference to the cluster manager for a given total number of executors.
+ * This can result in canceling pending requests or filing additional requests.
+ * This is currently only supported in YARN mode. Return whether the request is received.
+ */
+ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Requesting executors is currently only supported in YARN mode")
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.requestTotalExecutors(numExecutors)
+ case _ =>
+ logWarning("Requesting executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
/**
* :: DeveloperApi ::
* Request an additional number of executors from the cluster manager.
- * This is currently only supported in Yarn mode. Return whether the request is received.
+ * This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
@@ -1124,7 +1141,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
* Request that the cluster manager kill the specified executors.
- * This is currently only supported in Yarn mode. Return whether the request is received.
+ * This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def killExecutors(executorIds: Seq[String]): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f9ca93432bf41..99986c32b0fde 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -311,7 +311,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
/**
* Request an additional number of executors from the cluster manager.
- * Return whether the request is acknowledged.
+ * @return whether the request is acknowledged.
*/
final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
if (numAdditionalExecutors < 0) {
@@ -327,6 +327,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
doRequestTotalExecutors(newTotal)
}
+ /**
+ * Express a preference to the cluster manager for a given total number of executors. This can
+ * result in canceling pending requests or filing additional requests.
+ * @return whether the request is acknowledged.
+ */
+ final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized {
+ if (numAdditionalExecutors < 0) {
+ throw new IllegalArgumentException(
+ "Attempted to request a negative number of executor(s) " +
+ s"$numExecutors from the cluster manager. Please specify a positive number!")
+ }
+ numPendingExecutors =
+ math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0)
+ doRequestTotalExecutors(numExecutors)
+ }
+
/**
* Request executors from the cluster manager by specifying the total number desired,
* including existing pending and running executors.
@@ -337,7 +353,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* insufficient resources to satisfy the first request. We make the assumption here that the
* cluster manager will eventually fulfill all requests when resources free up.
*
- * Return whether the request is acknowledged.
+ * @return whether the request is acknowledged.
*/
protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 9eb87f016068d..5d96eabd34eee 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -175,6 +175,33 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
assert(numExecutorsPending(manager) === 9)
}
+ test("cancel pending executors when no longer needed") {
+ sc = createSparkContext(1, 10)
+ val manager = sc.executorAllocationManager.get
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5)))
+
+ assert(numExecutorsPending(manager) === 0)
+ assert(numExecutorsToAdd(manager) === 1)
+ assert(addExecutors(manager) === 1)
+ assert(numExecutorsPending(manager) === 1)
+ assert(numExecutorsToAdd(manager) === 2)
+ assert(addExecutors(manager) === 2)
+ assert(numExecutorsPending(manager) === 3)
+
+ val task1Info = createTaskInfo(0, 0, "executor-1")
+ sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info))
+
+ assert(numExecutorsToAdd(manager) === 4)
+ assert(addExecutors(manager) === 2)
+
+ val task2Info = createTaskInfo(1, 0, "executor-1")
+ sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info))
+ sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null))
+ sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null))
+
+ assert(adjustRequestedExecutors(manager) === -1)
+ }
+
test("remove executors") {
sc = createSparkContext(5, 10)
val manager = sc.executorAllocationManager.get
@@ -679,6 +706,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd)
private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending)
+ private val _maxNumExecutorsNeeded = PrivateMethod[Int]('maxNumExecutorsNeeded)
private val _executorsPendingToRemove =
PrivateMethod[collection.Set[String]]('executorsPendingToRemove)
private val _executorIds = PrivateMethod[collection.Set[String]]('executorIds)
@@ -686,6 +714,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes)
private val _schedule = PrivateMethod[Unit]('schedule)
private val _addExecutors = PrivateMethod[Int]('addExecutors)
+ private val _addOrCancelExecutorRequests = PrivateMethod[Int]('addOrCancelExecutorRequests)
private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor)
private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded)
private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved)
@@ -724,7 +753,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
}
private def addExecutors(manager: ExecutorAllocationManager): Int = {
- manager invokePrivate _addExecutors()
+ val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded()
+ manager invokePrivate _addExecutors(maxNumExecutorsNeeded)
+ }
+
+ private def adjustRequestedExecutors(manager: ExecutorAllocationManager): Int = {
+ manager invokePrivate _addOrCancelExecutorRequests(0L)
}
private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = {
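To make the sizing logic in this patch easier to follow, here is a minimal, standalone sketch of the arithmetic behind `targetNumExecutors`, `maxNumExecutorsNeeded`, and the add-or-cancel decision. The numbers are hypothetical and this is not the Spark implementation, only the calculation it performs.
```scala
// Standalone sketch of the allocation manager's sizing arithmetic (hypothetical values).
object AllocationSketch {
  // Executors we would have if the cluster manager fulfilled every outstanding request.
  def targetNumExecutors(pending: Int, registered: Int, pendingToRemove: Int): Int =
    pending + registered - pendingToRemove

  // Executors needed to run all pending and running tasks, rounded up.
  def maxNumExecutorsNeeded(pendingTasks: Int, runningTasks: Int, tasksPerExecutor: Int): Int =
    (pendingTasks + runningTasks + tasksPerExecutor - 1) / tasksPerExecutor

  def main(args: Array[String]): Unit = {
    val target = targetNumExecutors(pending = 3, registered = 5, pendingToRemove = 1)            // 7
    val needed = maxNumExecutorsNeeded(pendingTasks = 4, runningTasks = 2, tasksPerExecutor = 4) // 2
    if (needed < target) {
      // Mirrors addOrCancelExecutorRequests: ask the cluster manager for the lower total.
      println(s"requestTotalExecutors($needed) to cancel the surplus pending requests")
    } else {
      println(s"keep adding executors toward $needed once the add timer expires")
    }
  }
}
```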
From b640c841fca92bb0bca77267db2965ff8f79586f Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 10 Feb 2015 11:18:01 -0800
Subject: [PATCH 041/817] [HOTFIX][SPARK-4136] Fix compilation and tests
---
.../org/apache/spark/ExecutorAllocationClient.scala | 8 ++++----
.../cluster/CoarseGrainedSchedulerBackend.scala | 2 +-
.../apache/spark/ExecutorAllocationManagerSuite.scala | 10 ++++------
3 files changed, 9 insertions(+), 11 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index 079055e00c6c3..443830f8d03b6 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -26,25 +26,25 @@ private[spark] trait ExecutorAllocationClient {
/**
* Express a preference to the cluster manager for a given total number of executors.
* This can result in canceling pending requests or filing additional requests.
- * Return whether the request is acknowledged by the cluster manager.
+ * @return whether the request is acknowledged by the cluster manager.
*/
private[spark] def requestTotalExecutors(numExecutors: Int): Boolean
/**
* Request an additional number of executors from the cluster manager.
- * Return whether the request is acknowledged by the cluster manager.
+ * @return whether the request is acknowledged by the cluster manager.
*/
def requestExecutors(numAdditionalExecutors: Int): Boolean
/**
* Request that the cluster manager kill the specified executors.
- * Return whether the request is acknowledged by the cluster manager.
+ * @return whether the request is acknowledged by the cluster manager.
*/
def killExecutors(executorIds: Seq[String]): Boolean
/**
* Request that the cluster manager kill the specified executor.
- * Return whether the request is acknowledged by the cluster manager.
+ * @return whether the request is acknowledged by the cluster manager.
*/
def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 99986c32b0fde..6f77fa32ce37b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -333,7 +333,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* @return whether the request is acknowledged.
*/
final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized {
- if (numAdditionalExecutors < 0) {
+ if (numExecutors < 0) {
throw new IllegalArgumentException(
"Attempted to request a negative number of executor(s) " +
s"$numExecutors from the cluster manager. Please specify a positive number!")
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 5d96eabd34eee..d3123e854016b 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -297,15 +297,15 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
assert(removeExecutor(manager, "5"))
assert(removeExecutor(manager, "6"))
assert(executorIds(manager).size === 10)
- assert(addExecutors(manager) === 0) // still at upper limit
+ assert(addExecutors(manager) === 1)
onExecutorRemoved(manager, "3")
onExecutorRemoved(manager, "4")
assert(executorIds(manager).size === 8)
// Add succeeds again, now that we are no longer at the upper limit
// Number of executors added restarts at 1
- assert(addExecutors(manager) === 1)
- assert(addExecutors(manager) === 1) // upper limit reached again
+ assert(addExecutors(manager) === 2)
+ assert(addExecutors(manager) === 1) // upper limit reached
assert(addExecutors(manager) === 0)
assert(executorIds(manager).size === 8)
onExecutorRemoved(manager, "5")
@@ -313,9 +313,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
onExecutorAdded(manager, "13")
onExecutorAdded(manager, "14")
assert(executorIds(manager).size === 8)
- assert(addExecutors(manager) === 1)
- assert(addExecutors(manager) === 1) // upper limit reached again
- assert(addExecutors(manager) === 0)
+ assert(addExecutors(manager) === 0) // still at upper limit
onExecutorAdded(manager, "15")
onExecutorAdded(manager, "16")
assert(executorIds(manager).size === 10)
From 59272dad77eb95c5ae8e0652e00d02a2675cda53 Mon Sep 17 00:00:00 2001
From: wangfei
Date: Tue, 10 Feb 2015 11:54:30 -0800
Subject: [PATCH 042/817] [SPARK-5592][SQL] java.net.URISyntaxException when
 inserting data into a partitioned table
The following SQL gets a URISyntaxException:
```
create table sc as select *
from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows)
union all
select '2011-01-11', '2011-01-11+15:18:26' from src tablesample (1 rows)
union all
select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s;
create table sc_part (key string) partitioned by (ts string) stored as rcfile;
set hive.exec.dynamic.partition=true;
set hive.exec.dynamic.partition.mode=nonstrict;
insert overwrite table sc_part partition(ts) select * from sc;
```
java.net.URISyntaxException: Relative path in absolute URI: ts=2011-01-11+15:18:26
at org.apache.hadoop.fs.Path.initialize(Path.java:206)
at org.apache.hadoop.fs.Path.<init>(Path.java:172)
at org.apache.hadoop.fs.Path.<init>(Path.java:94)
at org.apache.spark.sql.hive.SparkHiveDynamicPartitionWriterContainer.org$apache$spark$sql$hive$SparkHiveDynamicPartitionWriterContainer$$newWriter$1(hiveWriterContainers.scala:230)
at org.apache.spark.sql.hive.SparkHiveDynamicPartitionWriterContainer$$anonfun$getLocalFileWriter$1.apply(hiveWriterContainers.scala:243)
at org.apache.spark.sql.hive.SparkHiveDynamicPartitionWriterContainer$$anonfun$getLocalFileWriter$1.apply(hiveWriterContainers.scala:243)
at scala.collection.mutable.MapLike$class.getOrElseUpdate(MapLike.scala:189)
at scala.collection.mutable.AbstractMap.getOrElseUpdate(Map.scala:91)
at org.apache.spark.sql.hive.SparkHiveDynamicPartitionWriterContainer.getLocalFileWriter(hiveWriterContainers.scala:243)
at org.apache.spark.sql.hive.execution.InsertIntoHiveTable$$anonfun$org$apache$spark$sql$hive$execution$InsertIntoHiveTable$$writeToFile$1$1.apply(InsertIntoHiveTable.scala:113)
at org.apache.spark.sql.hive.execution.InsertIntoHiveTable$$anonfun$org$apache$spark$sql$hive$execution$InsertIntoHiveTable$$writeToFile$1$1.apply(InsertIntoHiveTable.scala:105)
at scala.collection.Iterator$class.foreach(Iterator.scala:727)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
at org.apache.spark.sql.hive.execution.InsertIntoHiveTable.org$apache$spark$sql$hive$execution$InsertIntoHiveTable$$writeToFile$1(InsertIntoHiveTable.scala:105)
at org.apache.spark.sql.hive.execution.InsertIntoHiveTable$$anonfun$saveAsHiveFile$3.apply(InsertIntoHiveTable.scala:87)
at org.apache.spark.sql.hive.execution.InsertIntoHiveTable$$anonfun$saveAsHiveFile$3.apply(InsertIntoHiveTable.scala:87)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61)
at org.apache.spark.scheduler.Task.run(Task.scala:64)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:194)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1110)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:603)
at java.lang.Thread.run(Thread.java:722)
Caused by: java.net.URISyntaxException: Relative path in absolute URI: ts=2011-01-11+15:18:26
at java.net.URI.checkPath(URI.java:1804)
at java.net.URI.<init>(URI.java:752)
at org.apache.hadoop.fs.Path.initialize(Path.java:203)
Author: wangfei
Author: Fei Wang
Closes #4368 from scwf/SPARK-5592 and squashes the following commits:
aa55ef4 [Fei Wang] comments addressed
f8f8bb1 [wangfei] added test case
f24624f [wangfei] Merge branch 'master' of https://github.com/apache/spark into SPARK-5592
9998177 [wangfei] added test case
ea81daf [wangfei] fix URISyntaxException
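For illustration, here is a minimal sketch of the escaping now applied to each dynamic-partition path segment. The partition column name and value are just examples; FileUtils.escapePathName comes from hive-common, as imported in the patch, and the default partition name mirrors Hive's default setting.
```scala
import org.apache.hadoop.hive.common.FileUtils

object PartitionPathSketch {
  def main(args: Array[String]): Unit = {
    val defaultPartName = "__HIVE_DEFAULT_PARTITION__" // Hive's default for null/empty values
    val col = "ts"
    val rawVal: Any = "2011-01-11+15:18:26"
    val string = if (rawVal == null) null else String.valueOf(rawVal)
    val colString =
      if (string == null || string.isEmpty) defaultPartName
      else FileUtils.escapePathName(string) // characters such as ':' are percent-escaped
    // Prints something like /ts=2011-01-11+15%3A18%3A26, which Path/URI can now parse.
    println(s"/$col=$colString")
  }
}
```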
---
.../spark/sql/hive/hiveWriterContainers.scala | 12 +++++++++---
.../sql/hive/execution/HiveQuerySuite.scala | 16 ++++++++++++++++
2 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index aae175e426ade..f136e43acc8f2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred._
+import org.apache.hadoop.hive.common.FileUtils
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.Row
@@ -212,9 +213,14 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
.zip(row.toSeq.takeRight(dynamicPartColNames.length))
.map { case (col, rawVal) =>
val string = if (rawVal == null) null else String.valueOf(rawVal)
- s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}"
- }
- .mkString
+ val colString =
+ if (string == null || string.isEmpty) {
+ defaultPartName
+ } else {
+ FileUtils.escapePathName(string)
+ }
+ s"/$col=$colString"
+ }.mkString
def newWriter = {
val newFileSinkDesc = new FileSinkDesc(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 27047ce4b1b0b..405b200d05412 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -859,6 +859,22 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}
+ test("SPARK-5592: get java.net.URISyntaxException when dynamic partitioning") {
+ sql("""
+ |create table sc as select *
+ |from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows)
+ |union all
+ |select '2011-01-11', '2011-01-11+15:18:26' from src tablesample (1 rows)
+ |union all
+ |select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s
+ """.stripMargin)
+ sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile")
+ sql("set hive.exec.dynamic.partition=true")
+ sql("set hive.exec.dynamic.partition.mode=nonstrict")
+ sql("insert overwrite table sc_part partition(ts) select * from sc")
+ sql("drop table sc_part")
+ }
+
test("Partition spec validation") {
sql("DROP TABLE IF EXISTS dp_test")
sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)")
From c49a4049845c91b225e70fd630cdf6ddc055faf8 Mon Sep 17 00:00:00 2001
From: Miguel Peralvo
Date: Tue, 10 Feb 2015 19:54:52 +0000
Subject: [PATCH 043/817] [SPARK-5668] Display region in spark_ec2.py
get_existing_cluster()
Show the region in the messages displayed by get_existing_cluster(): the search, found, and error messages.
Author: Miguel Peralvo
Closes #4457 from MiguelPeralvo/patch-2 and squashes the following commits:
a5514c8 [Miguel Peralvo] Update spark_ec2.py
0a837b0 [Miguel Peralvo] Update spark_ec2.py
3923f36 [Miguel Peralvo] Update spark_ec2.py
4ecd9f9 [Miguel Peralvo] [SPARK-5668] Display region in spark_ec2.py get_existing_cluster()
---
ec2/spark_ec2.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index fe510f12bcec6..0ea7365d75b83 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -134,7 +134,7 @@ def parse_args():
help="Master instance type (leave empty for same as instance-type)")
parser.add_option(
"-r", "--region", default="us-east-1",
- help="EC2 region zone to launch instances in")
+ help="EC2 region used to launch instances in, or to find them in")
parser.add_option(
"-z", "--zone", default="",
help="Availability zone to launch instances in, or 'all' to spread " +
@@ -614,7 +614,8 @@ def launch_cluster(conn, opts, cluster_name):
# Get the EC2 instances in an existing cluster if available.
# Returns a tuple of lists of EC2 instance objects for the masters and slaves
def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
- print "Searching for existing cluster " + cluster_name + "..."
+ print "Searching for existing cluster " + cluster_name + " in region " \
+ + opts.region + "..."
reservations = conn.get_all_reservations()
master_nodes = []
slave_nodes = []
@@ -632,9 +633,11 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
return (master_nodes, slave_nodes)
else:
if master_nodes == [] and slave_nodes != []:
- print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master"
+ print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name \
+ + "-master" + " in region " + opts.region
else:
- print >> sys.stderr, "ERROR: Could not find any existing cluster"
+ print >> sys.stderr, "ERROR: Could not find any existing cluster" \
+ + " in region " + opts.region
sys.exit(1)
From de80b1ba4d3c4b1b3316d482d62e4668b996f6ac Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Tue, 10 Feb 2015 13:14:01 -0800
Subject: [PATCH 044/817] [SQL] Add toString to DataFrame/Column
Author: Michael Armbrust
Closes #4436 from marmbrus/dfToString and squashes the following commits:
8a3c35f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into dfToString
b72a81b [Michael Armbrust] add toString
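A REPL-style sketch of the new behavior, mirroring the assertions added to DataFrameSuite below; it assumes the sql/core test fixtures (TestData) and the Dsl imports used there.
```scala
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.TestData._

// testData has schema [key: int, value: string] in the test fixtures.
println(testData.toString)         // [key: int, value: string]
println(testData("key").toString)  // [key: int]
println($"test".toString)          // test  (pretty form, no expression ids)
```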
---
python/pyspark/sql/dataframe.py | 2 +-
.../sql/catalyst/expressions/Expression.scala | 12 ++++++++
.../expressions/namedExpressions.scala | 20 +++++++++++++
.../org/apache/spark/sql/DataFrame.scala | 8 +++++
.../org/apache/spark/sql/DataFrameImpl.scala | 10 +++----
.../apache/spark/sql/IncomputableColumn.scala | 2 ++
.../spark/sql/execution/debug/package.scala | 11 ++++++-
.../org/apache/spark/sql/DataFrameSuite.scala | 29 +++++++++++++++++++
8 files changed, 86 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cda704eea75f5..04be65fe241c4 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -447,7 +447,7 @@ def selectExpr(self, *expr):
`select` that accepts SQL expressions.
>>> df.selectExpr("age * 2", "abs(age)").collect()
- [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+ [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
"""
jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index cf14992ef835c..c32a4b886eb82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -66,6 +67,17 @@ abstract class Expression extends TreeNode[Expression] {
*/
def childrenResolved = !children.exists(!_.resolved)
+ /**
+ * Returns a string representation of this expression that does not have developer centric
+ * debugging information like the expression id.
+ */
+ def prettyString: String = {
+ transform {
+ case a: AttributeReference => PrettyAttribute(a.name)
+ case u: UnresolvedAttribute => PrettyAttribute(u.name)
+ }.toString
+ }
+
/**
* A set of helper functions that return the correct descendant of `scala.math.Numeric[T]` type
* and do any casting necessary of child evaluation.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index e6ab1fd8d7939..7f122e9d55734 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -190,6 +190,26 @@ case class AttributeReference(
override def toString: String = s"$name#${exprId.id}$typeSuffix"
}
+/**
+ * A place holder used when printing expressions without debugging information such as the
+ * expression id or the unresolved indicator.
+ */
+case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
+ type EvaluatedType = Any
+
+ override def toString = name
+
+ override def withNullability(newNullability: Boolean): Attribute = ???
+ override def newInstance(): Attribute = ???
+ override def withQualifiers(newQualifiers: Seq[String]): Attribute = ???
+ override def withName(newName: String): Attribute = ???
+ override def qualifiers: Seq[String] = ???
+ override def exprId: ExprId = ???
+ override def eval(input: Row): EvaluatedType = ???
+ override def nullable: Boolean = ???
+ override def dataType: DataType = ???
+}
+
object VirtualColumn {
val groupingIdName = "grouping__id"
def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 6abfb7853cf1c..04e0d09947492 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
+import scala.util.control.NonFatal
+
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
@@ -92,6 +94,12 @@ trait DataFrame extends RDDApi[Row] {
*/
def toDataFrame: DataFrame = this
+ override def toString =
+ try schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") catch {
+ case NonFatal(e) =>
+ s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+ }
+
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
* from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 73393295ab0a5..1ee16ad5161c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -201,13 +201,11 @@ private[sql] class DataFrameImpl protected[sql](
override def as(alias: Symbol): DataFrame = Subquery(alias.name, logicalPlan)
override def select(cols: Column*): DataFrame = {
- val exprs = cols.zipWithIndex.map {
- case (Column(expr: NamedExpression), _) =>
- expr
- case (Column(expr: Expression), _) =>
- Alias(expr, expr.toString)()
+ val namedExpressions = cols.map {
+ case Column(expr: NamedExpression) => expr
+ case Column(expr: Expression) => Alias(expr, expr.prettyString)()
}
- Project(exprs.toSeq, logicalPlan)
+ Project(namedExpressions.toSeq, logicalPlan)
}
override def select(col: String, cols: String*): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 0600dcc226b4d..ce0557b88196f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -40,6 +40,8 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
throw new UnsupportedOperationException("Cannot run this method on an UncomputableColumn")
}
+ override def toString = expr.prettyString
+
override def isComputable: Boolean = false
override val sqlContext: SQLContext = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5cc67cdd13944..acef49aabfe70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet
import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.types._
@@ -37,6 +37,15 @@ import org.apache.spark.sql.types._
*/
package object debug {
+ /**
+ * Augments [[SQLContext]] with debug methods.
+ */
+ implicit class DebugSQLContext(sqlContext: SQLContext) {
+ def debug() = {
+ sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false")
+ }
+ }
+
/**
* :: DeveloperApi ::
* Augments [[DataFrame]]s with debug methods.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5aa3db720c886..02623f73c7f76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import org.apache.spark.sql.TestData._
+
import scala.language.postfixOps
import org.apache.spark.sql.Dsl._
@@ -53,6 +55,33 @@ class DataFrameSuite extends QueryTest {
TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
}
+ test("dataframe toString") {
+ assert(testData.toString === "[key: int, value: string]")
+ assert(testData("key").toString === "[key: int]")
+ }
+
+ test("incomputable toString") {
+ assert($"test".toString === "test")
+ }
+
+ test("invalid plan toString, debug mode") {
+ val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis
+ TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
+
+ // Turn on debug mode so we can see invalid query plans.
+ import org.apache.spark.sql.execution.debug._
+ TestSQLContext.debug()
+
+ val badPlan = testData.select('badColumn)
+
+ assert(badPlan.toString contains badPlan.queryExecution.toString,
+ "toString on bad query plans should include the query execution but was:\n" +
+ badPlan.toString)
+
+ // Set the flag back to original value before this test.
+ TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
+ }
+
test("table scan") {
checkAnswer(
testData,
From f98707c043f1be9569ec774796edb783132773a8 Mon Sep 17 00:00:00 2001
From: OopsOutOfMemory
Date: Tue, 10 Feb 2015 13:20:15 -0800
Subject: [PATCH 045/817] [SPARK-5686][SQL] Add show current roles command in
HiveQl
show current roles
Author: OopsOutOfMemory
Closes #4471 from OopsOutOfMemory/show_current_role and squashes the following commits:
1c6b210 [OopsOutOfMemory] add show current roles
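With TOK_SHOW_SET_ROLE whitelisted, the statement is passed straight through to Hive as a native command. A minimal usage sketch, assuming an existing HiveContext named hiveContext:
```scala
hiveContext.sql("SHOW CURRENT ROLES").collect().foreach(println)
```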
---
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 1 +
1 file changed, 1 insertion(+)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 969868aef2917..8618301ba84d6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -77,6 +77,7 @@ private[hive] object HiveQl {
"TOK_REVOKE",
"TOK_SHOW_GRANT",
"TOK_SHOW_ROLE_GRANT",
+ "TOK_SHOW_SET_ROLE",
"TOK_CREATEFUNCTION",
"TOK_DROPFUNCTION",
From fd2c032f95bbee342ca539df9e44927482981659 Mon Sep 17 00:00:00 2001
From: MechCoder
Date: Tue, 10 Feb 2015 14:05:55 -0800
Subject: [PATCH 046/817] [SPARK-5021] [MLlib] Gaussian Mixture now supports
Sparse Input
Following discussion in the Jira.
Author: MechCoder
Closes #4459 from MechCoder/sparse_gmm and squashes the following commits:
1b18dab [MechCoder] Rewrite syr for sparse matrices
e579041 [MechCoder] Add test for covariance matrix
5cb370b [MechCoder] Separate tests for sparse data
5e096bd [MechCoder] Alphabetize and correct error message
e180f4c [MechCoder] [SPARK-5021] Gaussian Mixture now supports Sparse Input
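The key building block is a sparse rank-1 update, A := alpha * x * x^T + A, that touches only the index pairs where x is nonzero. Here is a standalone sketch of that loop over plain arrays (column-major A); it is not the MLlib code itself.
```scala
object SparseSyrSketch {
  // Adds alpha * x * x^T to a dense n x n matrix stored column-major in aValues,
  // visiting only the nonzero entries of the sparse vector x.
  def syrSparse(alpha: Double, n: Int, xIndices: Array[Int], xValues: Array[Double],
      aValues: Array[Double]): Unit = {
    var i = 0
    while (i < xValues.length) {
      val multiplier = alpha * xValues(i)
      val offset = xIndices(i) * n // start of column xIndices(i)
      var j = 0
      while (j < xValues.length) {
        aValues(xIndices(j) + offset) += multiplier * xValues(j)
        j += 1
      }
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    // x = (1.0, 0.0, 3.0, 4.0) stored sparsely; A starts as the 4 x 4 zero matrix.
    val a = Array.fill(16)(0.0)
    syrSparse(0.1, 4, Array(0, 2, 3), Array(1.0, 3.0, 4.0), a)
    println(a.grouped(4).map(_.mkString(" ")).mkString("\n"))
  }
}
```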
---
.../mllib/clustering/GaussianMixture.scala | 31 ++++-----
.../org/apache/spark/mllib/linalg/BLAS.scala | 36 +++++++++-
.../distribution/MultivariateGaussian.scala | 10 +--
.../clustering/GaussianMixtureSuite.scala | 66 +++++++++++++++++--
.../apache/spark/mllib/linalg/BLASSuite.scala | 8 +++
5 files changed, 125 insertions(+), 26 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index 0be3014de862e..80584ef5e5979 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.IndexedSeq
-import breeze.linalg.{DenseMatrix => BreezeMatrix, DenseVector => BreezeVector, Transpose, diag}
+import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, SparseVector => BSV,
+ Transpose, Vector => BV}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Matrices, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, DenseVector, DenseMatrix, Matrices,
+ SparseVector, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -130,7 +132,7 @@ class GaussianMixture private (
val sc = data.sparkContext
// we will operate on the data as breeze data
- val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
+ val breezeData = data.map(_.toBreeze).cache()
// Get length of the input vectors
val d = breezeData.first().length
@@ -148,7 +150,7 @@ class GaussianMixture private (
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
- })
+ })
}
}
@@ -169,7 +171,7 @@ class GaussianMixture private (
var i = 0
while (i < k) {
val mu = sums.means(i) / sums.weights(i)
- BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
+ BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
weights(i) = sums.weights(i) / sumWeights
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
@@ -185,8 +187,8 @@ class GaussianMixture private (
}
/** Average of dense breeze vectors */
- private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
- val v = BreezeVector.zeros[Double](x(0).length)
+ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
+ val v = BDV.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
v / x.length.toDouble
}
@@ -195,10 +197,10 @@ class GaussianMixture private (
* Construct matrix where diagonal entries are element-wise
* variance of input vectors (computes biased variance)
*/
- private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
+ private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = {
val mu = vectorMean(x)
- val ss = BreezeVector.zeros[Double](x(0).length)
- x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
+ val ss = BDV.zeros[Double](x(0).length)
+ x.foreach(xi => ss += (xi - mu) :^ 2.0)
diag(ss / x.length.toDouble)
}
}
@@ -207,7 +209,7 @@ class GaussianMixture private (
private object ExpectationSum {
def zero(k: Int, d: Int): ExpectationSum = {
new ExpectationSum(0.0, Array.fill(k)(0.0),
- Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
+ Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
}
// compute cluster contributions for each input point
@@ -215,19 +217,18 @@ private object ExpectationSum {
def add(
weights: Array[Double],
dists: Array[MultivariateGaussian])
- (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
+ (sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
val p = weights.zip(dists).map {
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
}
val pSum = p.sum
sums.logLikelihood += math.log(pSum)
- val xxt = x * new Transpose(x)
var i = 0
while (i < sums.k) {
p(i) /= pSum
sums.weights(i) += p(i)
sums.means(i) += x * p(i)
- BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
+ BLAS.syr(p(i), Vectors.fromBreeze(x),
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
i = i + 1
}
@@ -239,7 +240,7 @@ private object ExpectationSum {
private class ExpectationSum(
var logLikelihood: Double,
val weights: Array[Double],
- val means: Array[BreezeVector[Double]],
+ val means: Array[BDV[Double]],
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
val k = weights.length
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 079f7ca564a92..87052e1ba8539 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -235,12 +235,24 @@ private[spark] object BLAS extends Serializable with Logging {
* @param x the vector x that contains the n elements.
* @param A the symmetric matrix A. Size of n x n.
*/
- def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ def syr(alpha: Double, x: Vector, A: DenseMatrix) {
val mA = A.numRows
val nA = A.numCols
- require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA")
+ require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
+ x match {
+ case dv: DenseVector => syr(alpha, dv, A)
+ case sv: SparseVector => syr(alpha, sv, A)
+ case _ =>
+ throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
+ }
+ }
+
+ private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ val nA = A.numRows
+ val mA = A.numCols
+
nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)
// Fill lower triangular part of A
@@ -255,6 +267,26 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
+ private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
+ val mA = A.numCols
+ val xIndices = x.indices
+ val xValues = x.values
+ val nnz = xValues.length
+ val Avalues = A.values
+
+ var i = 0
+ while (i < nnz) {
+ val multiplier = alpha * xValues(i)
+ val offset = xIndices(i) * mA
+ var j = 0
+ while (j < nnz) {
+ Avalues(xIndices(j) + offset) += multiplier * xValues(j)
+ j += 1
+ }
+ i += 1
+ }
+ }
+
/**
* C := alpha * A * B + beta * C
* @param alpha a scalar to scale the multiplication A * B.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index fd186b5ee6f72..cd6add9d60b0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.stat.distribution
-import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym}
+import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV}
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
@@ -62,21 +62,21 @@ class MultivariateGaussian (
/** Returns density of this multivariate Gaussian at given point, x */
def pdf(x: Vector): Double = {
- pdf(x.toBreeze.toDenseVector)
+ pdf(x.toBreeze)
}
/** Returns the log-density of this multivariate Gaussian at given point, x */
def logpdf(x: Vector): Double = {
- logpdf(x.toBreeze.toDenseVector)
+ logpdf(x.toBreeze)
}
/** Returns density of this multivariate Gaussian at given point, x */
- private[mllib] def pdf(x: DBV[Double]): Double = {
+ private[mllib] def pdf(x: BV[Double]): Double = {
math.exp(logpdf(x))
}
/** Returns the log-density of this multivariate Gaussian at given point, x */
- private[mllib] def logpdf(x: DBV[Double]): Double = {
+ private[mllib] def logpdf(x: BV[Double]): Double = {
val delta = x - breezeMu
val v = rootSigmaInv * delta
u + v.t * v * -0.5
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index c2cd56ea40adc..1b46a4012d731 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -31,7 +31,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense(5.0, 10.0),
Vectors.dense(4.0, 11.0)
))
-
+
// expectations
val Ew = 1.0
val Emu = Vectors.dense(5.0, 10.0)
@@ -44,6 +44,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
}
+
}
test("two clusters") {
@@ -54,7 +55,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))
-
+
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
@@ -63,7 +64,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
)
)
-
+
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
@@ -72,7 +73,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
.setK(2)
.setInitialModel(initialGmm)
.run(data)
-
+
assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
@@ -80,4 +81,61 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
+
+ test("single cluster with sparse data") {
+ val data = sc.parallelize(Array(
+ Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)),
+ Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)),
+ Vectors.sparse(3, Array(1), Array(6.0))
+ ))
+
+ val Ew = 1.0
+ val Emu = Vectors.dense(2.0, 2.0, 2.0)
+ val Esigma = Matrices.dense(3, 3,
+ Array(8.0 / 3.0, -4.0, 4.0 / 3.0, -4.0, 8.0, -4.0, 4.0 / 3.0, -4.0, 8.0 / 3.0)
+ )
+
+ val seeds = Array(42, 1994, 27, 11, 0)
+ seeds.foreach { seed =>
+ val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data)
+ assert(gmm.weights(0) ~== Ew absTol 1E-5)
+ assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
+ assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
+ }
+ }
+
+ test("two clusters with sparse data") {
+ val data = sc.parallelize(Array(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ ))
+
+ val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
+ // we set an initial gaussian to induce expected results
+ val initialGmm = new GaussianMixtureModel(
+ Array(0.5, 0.5),
+ Array(
+ new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
+ new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
+ )
+ )
+ val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
+ val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
+ val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
+
+ val sparseGMM = new GaussianMixture()
+ .setK(2)
+ .setInitialModel(initialGmm)
+ .run(sparseData)
+
+ assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3)
+ assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3)
+ assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3)
+ assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3)
+ assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
+ assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index b0b78acd6df16..002cb253862b5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -166,6 +166,14 @@ class BLASSuite extends FunSuite {
syr(alpha, y, dA)
}
}
+
+ val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0))
+ val dD = new DenseMatrix(4, 4,
+ Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
+ syr(0.1, xSparse, dD)
+ val expectedSparse = new DenseMatrix(4, 4,
+ Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4))
+ assert(dD ~== expectedSparse absTol 1e-15)
}
test("gemm") {
From 5820961289eb98e45eb467efa316c7592b8d619c Mon Sep 17 00:00:00 2001
From: Brennon York
Date: Tue, 10 Feb 2015 14:57:00 -0800
Subject: [PATCH 047/817] [SPARK-5343][GraphX]: ShortestPaths traverses
backwards
Corrected the logic in ShortestPaths so that the calculation runs forward rather than backward. Output before looked like:
```scala
import org.apache.spark.graphx._
val g = Graph(sc.makeRDD(Array((1L,""), (2L,""), (3L,""))), sc.makeRDD(Array(Edge(1L,2L,""), Edge(2L,3L,""))))
lib.ShortestPaths.run(g,Array(3)).vertices.collect
// res0: Array[(org.apache.spark.graphx.VertexId, org.apache.spark.graphx.lib.ShortestPaths.SPMap)] = Array((1,Map()), (3,Map(3 -> 0)), (2,Map()))
lib.ShortestPaths.run(g,Array(1)).vertices.collect
// res1: Array[(org.apache.spark.graphx.VertexId, org.apache.spark.graphx.lib.ShortestPaths.SPMap)] = Array((1,Map(1 -> 0)), (3,Map(1 -> 2)), (2,Map(1 -> 1)))
```
And new output after the changes looks like:
```scala
import org.apache.spark.graphx._
val g = Graph(sc.makeRDD(Array((1L,""), (2L,""), (3L,""))), sc.makeRDD(Array(Edge(1L,2L,""), Edge(2L,3L,""))))
lib.ShortestPaths.run(g,Array(3)).vertices.collect
// res0: Array[(org.apache.spark.graphx.VertexId, org.apache.spark.graphx.lib.ShortestPaths.SPMap)] = Array((1,Map(3 -> 2)), (2,Map(3 -> 1)), (3,Map(3 -> 0)))
lib.ShortestPaths.run(g,Array(1)).vertices.collect
// res1: Array[(org.apache.spark.graphx.VertexId, org.apache.spark.graphx.lib.ShortestPaths.SPMap)] = Array((1,Map(1 -> 0)), (2,Map()), (3,Map()))
```
Author: Brennon York
Closes #4478 from brennonyork/SPARK-5343 and squashes the following commits:
aa57f83 [Brennon York] updated to set ShortestPaths to run 'forward' rather than 'backward'
---
.../scala/org/apache/spark/graphx/lib/ShortestPaths.scala | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
index 590f0474957dd..179f2843818e0 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
@@ -61,8 +61,8 @@ object ShortestPaths {
}
def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
- val newAttr = incrementMap(edge.srcAttr)
- if (edge.dstAttr != addMaps(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr))
+ val newAttr = incrementMap(edge.dstAttr)
+ if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr))
else Iterator.empty
}
From 52983d7f4f1a155433b6df3687cf5dc71804cfd5 Mon Sep 17 00:00:00 2001
From: Sephiroth-Lin
Date: Tue, 10 Feb 2015 23:23:35 +0000
Subject: [PATCH 048/817] [SPARK-5644] [Core] Delete tmp dir when sc is stopped
When we run the driver as a service and only call sc.stop() after each job, the tmp dirs created by HttpFileServer and SparkEnv are not deleted; they linger until the service process exits. So we need to delete these tmp dirs directly when sc is stopped.
Author: Sephiroth-Lin
Closes #4412 from Sephiroth-Lin/bug-fix-master-01 and squashes the following commits:
fbbc785 [Sephiroth-Lin] using an interpolated string
b968e14 [Sephiroth-Lin] using an interpolated string
4edf394 [Sephiroth-Lin] rename the variable and update comment
1339c96 [Sephiroth-Lin] add a member to store the reference of tmp dir
b2018a5 [Sephiroth-Lin] check sparkFilesDir before delete
f48a3c6 [Sephiroth-Lin] don't check sparkFilesDir, check executorId
dd9686e [Sephiroth-Lin] format code
b38e0f0 [Sephiroth-Lin] add dir check before delete
d7ccc64 [Sephiroth-Lin] Change log level
1d70926 [Sephiroth-Lin] update comment
e2a2b1b [Sephiroth-Lin] update comment
aeac518 [Sephiroth-Lin] Delete tmp dir when sc is stop
c0d5b28 [Sephiroth-Lin] Delete tmp dir when sc is stop
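For context, a sketch of the long-running-driver scenario this patch targets (hypothetical service code, local master only for illustration): the same JVM creates and stops several SparkContexts, and before this change each stop left the HttpFileServer/SparkEnv temp dirs behind until the JVM exited.
```scala
import org.apache.spark.{SparkConf, SparkContext}

object DriverServiceSketch {
  def main(args: Array[String]): Unit = {
    for (jobId <- 1 to 3) {
      val sc = new SparkContext(
        new SparkConf().setAppName(s"job-$jobId").setMaster("local[2]"))
      sc.parallelize(1 to 100).count()
      // With this patch, the driver-side temp dirs are deleted here instead of
      // lingering until the service JVM exits.
      sc.stop()
    }
  }
}
```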
---
.../org/apache/spark/HttpFileServer.scala | 9 ++++++
.../scala/org/apache/spark/SparkEnv.scala | 29 ++++++++++++++++++-
2 files changed, 37 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index 3f33332a81eaf..7e706bcc42f04 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -50,6 +50,15 @@ private[spark] class HttpFileServer(
def stop() {
httpServer.stop()
+
+ // If we only stop the SparkContext while the driver process keeps running as a service, we
+ // need to delete this tmp dir here; otherwise too many tmp dirs accumulate
+ try {
+ Utils.deleteRecursively(baseDir)
+ } catch {
+ case e: Exception =>
+ logWarning(s"Exception while deleting Spark temp dir: ${baseDir.getAbsolutePath}", e)
+ }
}
def addFile(file: File) : String = {
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index f25db7f8de565..b63bea5b102b6 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -76,6 +76,8 @@ class SparkEnv (
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
+ private var driverTmpDirToDelete: Option[String] = None
+
private[spark] def stop() {
isStopped = true
pythonWorkers.foreach { case(key, worker) => worker.stop() }
@@ -93,6 +95,22 @@ class SparkEnv (
// actorSystem.awaitTermination()
// Note that blockTransferService is stopped by BlockManager since it is started by it.
+
+ // If only the SparkContext is stopped while the driver process keeps running as a service,
+ // we need to delete the tmp dir here; otherwise too many tmp dirs will accumulate.
+ // Only the tmp dir created by the driver needs to be deleted, because on executors
+ // sparkFilesDir points to the current working directory, which we do not need to delete.
+ driverTmpDirToDelete match {
+ case Some(path) => {
+ try {
+ Utils.deleteRecursively(new File(path))
+ } catch {
+ case e: Exception =>
+ logWarning(s"Exception while deleting Spark temp dir: $path", e)
+ }
+ }
+ case None => // Only the tmp dir created by the driver needs deletion, so do nothing on executors
+ }
}
private[spark]
@@ -350,7 +368,7 @@ object SparkEnv extends Logging {
"levels using the RDD.persist() method instead.")
}
- new SparkEnv(
+ val envInstance = new SparkEnv(
executorId,
actorSystem,
serializer,
@@ -367,6 +385,15 @@ object SparkEnv extends Logging {
metricsSystem,
shuffleMemoryManager,
conf)
+
+ // Keep a reference to the tmp dir created by the driver so that it can be deleted when stop()
+ // is called. This is only needed on the driver: the driver may run as a long-lived service, and
+ // if the tmp dir is not deleted when sc is stopped, too many tmp dirs will accumulate.
+ if (isDriver) {
+ envInstance.driverTmpDirToDelete = Some(sparkFilesDir)
+ }
+
+ envInstance
}
/**
From 91e3512544d9ab684799ac9a9c341ab465e1b427 Mon Sep 17 00:00:00 2001
From: "Sheng, Li"
Date: Wed, 11 Feb 2015 00:59:46 +0000
Subject: [PATCH 049/817] [SQL][Minor] Correct some comments
Author: Sheng, Li
Author: OopsOutOfMemory
Closes #4508 from OopsOutOfMemory/cmt and squashes the following commits:
d8a68c6 [Sheng, Li] Update ddl.scala
f24aeaf [OopsOutOfMemory] correct style
---
sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index a692ef51b31ed..bf2ad14763e9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -141,7 +141,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
/*
* describe [extended] table avroTable
- * This will display all columns of table `avroTable` includes column_name,column_type,nullable
+ * This will display all columns of table `avroTable`, including column_name, column_type and comment
*/
protected lazy val describeTable: Parser[LogicalPlan] =
(DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ {
From 2d50a010ff57a861b13c2088ac048662d535f5e7 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Tue, 10 Feb 2015 17:02:44 -0800
Subject: [PATCH 050/817] [SPARK-5725] [SQL] Fixes ParquetRelation2.equals
Review: https://reviewable.io/reviews/apache/spark/4513
Author: Cheng Lian
Closes #4513 from liancheng/spark-5725 and squashes the following commits:
bf6a087 [Cheng Lian] Fixes ParquetRelation2.equals
---
.../src/main/scala/org/apache/spark/sql/parquet/newParquet.scala | 1 +
1 file changed, 1 insertion(+)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 49d46334b6525..04804f78f5c34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -152,6 +152,7 @@ case class ParquetRelation2
paths.toSet == relation.paths.toSet &&
maybeMetastoreSchema == relation.maybeMetastoreSchema &&
(shouldMergeSchemas == relation.shouldMergeSchemas || schema == relation.schema)
+ case _ => false
}
private[sql] def sparkContext = sqlContext.sparkContext
From e28b6bdbb5c5e4fd62ec0b547b77719c3f7e476e Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Tue, 10 Feb 2015 17:06:12 -0800
Subject: [PATCH 051/817] [SQL] Make Options in the data source API CREATE
TABLE statements optional.
Users no longer need to include an empty `Options()` clause in a CREATE TABLE statement when no options are provided.
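For example, the CTAS statement exercised in the test suite below no longer needs an empty
OPTIONS clause (a sketch assuming a HiveContext-backed `sqlContext` and an existing `jsonTable`;
the two statements are alternatives, not meant to run back to back):

    // Before: an empty OPTIONS clause was required just to satisfy the parser.
    sqlContext.sql(
      """CREATE TABLE ctasJsonTable
        |USING org.apache.spark.sql.json.DefaultSource
        |OPTIONS (
        |) AS
        |SELECT * FROM jsonTable
      """.stripMargin)

    // After: the OPTIONS clause can simply be omitted when there are no options.
    sqlContext.sql(
      """CREATE TABLE ctasJsonTable
        |USING org.apache.spark.sql.json.DefaultSource
        |AS
        |SELECT * FROM jsonTable
      """.stripMargin)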
Author: Yin Huai
Closes #4515 from yhuai/makeOptionsOptional and squashes the following commits:
1a898d3 [Yin Huai] Make options optional.
---
.../src/main/scala/org/apache/spark/sql/sources/ddl.scala | 7 ++++---
.../apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 4 +---
2 files changed, 5 insertions(+), 6 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index bf2ad14763e9f..9f64f761002c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -106,13 +106,14 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
protected lazy val createTable: Parser[LogicalPlan] =
(
(CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident
- ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ~ (AS ~> restInput).? ^^ {
+ ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ {
case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query =>
if (temp.isDefined && allowExisting.isDefined) {
throw new DDLException(
"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
}
+ val options = opts.getOrElse(Map.empty[String, String])
if (query.isDefined) {
if (columns.isDefined) {
throw new DDLException(
@@ -121,7 +122,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
CreateTableUsingAsSelect(tableName,
provider,
temp.isDefined,
- opts,
+ options,
allowExisting.isDefined,
query.get)
} else {
@@ -131,7 +132,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
userSpecifiedSchema,
provider,
temp.isDefined,
- opts,
+ options,
allowExisting.isDefined)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 036efa84d7c85..9ce058909f429 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -361,9 +361,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
s"""
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
- |OPTIONS (
- |
- |) AS
+ |AS
|SELECT * FROM jsonTable
""".stripMargin)
From ed167e70c6d355f39b366ea0d3b92dd26d826a0b Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Tue, 10 Feb 2015 17:19:10 -0800
Subject: [PATCH 052/817] [SPARK-5493] [core] Add option to impersonate user.
Hadoop has a feature that allows users to impersonate other users
when submitting applications or talking to HDFS, for example. These
impersonated users are generally referred to as "proxy users".
Services such as Oozie or Hive use this feature to run applications
as the requesting user.
This change makes SparkSubmit accept a new command line option to
run the application as a proxy user. It also fixes the plumbing
of the user name through the UI (and a couple of other places) to
refer to the correct user running the application, which can be
different from `sys.props("user.name")` even without proxies (e.g.
when using Kerberos).
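For reference, the Hadoop impersonation mechanism this option builds on boils down to the
following pattern (a minimal sketch; the user name "alice" and the runAs helper are
illustrative and not part of this patch):

    import java.security.PrivilegedExceptionAction
    import org.apache.hadoop.security.UserGroupInformation

    object ProxyUserExample {
      // Run `body` as `proxyUserName`, impersonated on top of the real (authenticated) user.
      def runAs(proxyUserName: String)(body: => Unit): Unit = {
        val proxyUgi = UserGroupInformation.createProxyUser(
          proxyUserName, UserGroupInformation.getCurrentUser())
        proxyUgi.doAs(new PrivilegedExceptionAction[Unit] {
          override def run(): Unit = body
        })
      }

      def main(args: Array[String]): Unit = {
        runAs("alice") {
          // Inside doAs, getCurrentUser() reports the impersonated (proxy) user.
          println(s"Running as: ${UserGroupInformation.getCurrentUser().getShortUserName()}")
        }
      }
    }

From the command line, the same effect is obtained with the new `--proxy-user NAME` flag
added to spark-submit by this patch.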
Author: Marcelo Vanzin
Closes #4405 from vanzin/SPARK-5493 and squashes the following commits:
df82427 [Marcelo Vanzin] Clarify the reason for the special exception handling.
05bfc08 [Marcelo Vanzin] Remove unneeded annotation.
4840de9 [Marcelo Vanzin] Review feedback.
8af06ff [Marcelo Vanzin] Fix usage string.
2e4fa8f [Marcelo Vanzin] Merge branch 'master' into SPARK-5493
b6c947d [Marcelo Vanzin] Merge branch 'master' into SPARK-5493
0540d38 [Marcelo Vanzin] [SPARK-5493] [core] Add option to impersonate user.
---
bin/utils.sh | 3 +-
bin/windows-utils.cmd | 1 +
.../org/apache/spark/SecurityManager.scala | 3 +-
.../scala/org/apache/spark/SparkContext.scala | 16 ++----
.../apache/spark/deploy/SparkHadoopUtil.scala | 19 +++----
.../org/apache/spark/deploy/SparkSubmit.scala | 56 ++++++++++++++++---
.../spark/deploy/SparkSubmitArguments.scala | 7 +++
.../scala/org/apache/spark/util/Utils.scala | 11 ++++
8 files changed, 82 insertions(+), 34 deletions(-)
diff --git a/bin/utils.sh b/bin/utils.sh
index 2241200082018..748dbe345a74c 100755
--- a/bin/utils.sh
+++ b/bin/utils.sh
@@ -35,7 +35,8 @@ function gatherSparkSubmitOpts() {
--master | --deploy-mode | --class | --name | --jars | --packages | --py-files | --files | \
--conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \
--driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \
- --total-executor-cores | --executor-cores | --queue | --num-executors | --archives)
+ --total-executor-cores | --executor-cores | --queue | --num-executors | --archives | \
+ --proxy-user)
if [[ $# -lt 2 ]]; then
"$SUBMIT_USAGE_FUNCTION"
exit 1;
diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd
index 567b8733f7f77..0cf9e87ca554b 100644
--- a/bin/windows-utils.cmd
+++ b/bin/windows-utils.cmd
@@ -33,6 +33,7 @@ SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--
SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>"
SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>"
SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>"
+SET opts="%opts:~1,-1% \<--proxy-user\>"
echo %1 | findstr %opts% >nul
if %ERRORLEVEL% equ 0 (
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 88d35a4bacc6e..3653f724ba192 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.network.sasl.SecretKeyHolder
+import org.apache.spark.util.Utils
/**
* Spark class responsible for security.
@@ -203,7 +204,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
// always add the current user and SPARK_USER to the viewAcls
private val defaultAclUsers = Set[String](System.getProperty("user.name", ""),
- Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty)
+ Utils.getCurrentUserName())
setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", ""))
setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", ""))
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 04ca5d1019e4b..53fce6b0defdf 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -191,7 +191,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// log out Spark Version in Spark driver log
logInfo(s"Running Spark version $SPARK_VERSION")
-
+
private[spark] val conf = config.clone()
conf.validateSettings()
@@ -335,11 +335,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
executorEnvs ++= conf.getExecutorEnv
// Set SPARK_USER for user who is running SparkContext.
- val sparkUser = Option {
- Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name"))
- }.getOrElse {
- SparkContext.SPARK_UNKNOWN_USER
- }
+ val sparkUser = Utils.getCurrentUserName()
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
@@ -826,7 +822,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
assertNotStopped()
- // The call to new NewHadoopJob automatically adds security credentials to conf,
+ // The call to new NewHadoopJob automatically adds security credentials to conf,
// so we don't need to explicitly add them ourselves
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
@@ -1626,8 +1622,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
- /**
- * Default min number of partitions for Hadoop RDDs when not given by user
+ /**
+ * Default min number of partitions for Hadoop RDDs when not given by user
* Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
* The reasons for this are discussed in https://github.com/mesos/spark/pull/718
*/
@@ -1844,8 +1840,6 @@ object SparkContext extends Logging {
private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"
- private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
-
private[spark] val DRIVER_IDENTIFIER = "<driver>"
// The following deprecated objects have already been copied to `object AccumulatorParam` to
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 03238e9fa0088..e0a32fb65cd51 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -52,18 +52,13 @@ class SparkHadoopUtil extends Logging {
* do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems
*/
def runAsSparkUser(func: () => Unit) {
- val user = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
- if (user != SparkContext.SPARK_UNKNOWN_USER) {
- logDebug("running as user: " + user)
- val ugi = UserGroupInformation.createRemoteUser(user)
- transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
- ugi.doAs(new PrivilegedExceptionAction[Unit] {
- def run: Unit = func()
- })
- } else {
- logDebug("running as SPARK_UNKNOWN_USER")
- func()
- }
+ val user = Utils.getCurrentUserName()
+ logDebug("running as user: " + user)
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
}
def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index c4bc5054d61a1..80cc0587286b1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -18,12 +18,14 @@
package org.apache.spark.deploy
import java.io.{File, PrintStream}
-import java.lang.reflect.{InvocationTargetException, Modifier}
+import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
+import java.security.PrivilegedExceptionAction
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
import org.apache.ivy.core.module.descriptor._
@@ -79,7 +81,7 @@ object SparkSubmit {
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
// Exposed for testing
- private[spark] var exitFn: () => Unit = () => System.exit(-1)
+ private[spark] var exitFn: () => Unit = () => System.exit(1)
private[spark] var printStream: PrintStream = System.err
private[spark] def printWarning(str: String) = printStream.println("Warning: " + str)
private[spark] def printErrorAndExit(str: String) = {
@@ -126,6 +128,34 @@ object SparkSubmit {
*/
private[spark] def submit(args: SparkSubmitArguments): Unit = {
val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
+
+ def doRunMain(): Unit = {
+ if (args.proxyUser != null) {
+ val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser,
+ UserGroupInformation.getCurrentUser())
+ try {
+ proxyUser.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose)
+ }
+ })
+ } catch {
+ case e: Exception =>
+ // Hadoop's AuthorizationException suppresses the exception's stack trace, which
+ // makes the message printed to the output by the JVM not very helpful. Instead,
+ // detect exceptions with empty stack traces here, and treat them differently.
+ if (e.getStackTrace().length == 0) {
+ printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
+ exitFn()
+ } else {
+ throw e
+ }
+ }
+ } else {
+ runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose)
+ }
+ }
+
// In standalone cluster mode, there are two submission gateways:
// (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper
// (2) The new REST-based gateway introduced in Spark 1.3
@@ -134,7 +164,7 @@ object SparkSubmit {
if (args.isStandaloneCluster && args.useRest) {
try {
printStream.println("Running Spark using the REST application submission protocol.")
- runMain(childArgs, childClasspath, sysProps, childMainClass)
+ doRunMain()
} catch {
// Fail over to use the legacy submission gateway
case e: SubmitRestConnectionException =>
@@ -145,7 +175,7 @@ object SparkSubmit {
}
// In all other modes, just run the main class as prepared
} else {
- runMain(childArgs, childClasspath, sysProps, childMainClass)
+ doRunMain()
}
}
@@ -457,7 +487,7 @@ object SparkSubmit {
childClasspath: Seq[String],
sysProps: Map[String, String],
childMainClass: String,
- verbose: Boolean = false) {
+ verbose: Boolean): Unit = {
if (verbose) {
printStream.println(s"Main class:\n$childMainClass")
printStream.println(s"Arguments:\n${childArgs.mkString("\n")}")
@@ -507,13 +537,21 @@ object SparkSubmit {
if (!Modifier.isStatic(mainMethod.getModifiers)) {
throw new IllegalStateException("The main method in the given main class must be static")
}
+
+ def findCause(t: Throwable): Throwable = t match {
+ case e: UndeclaredThrowableException =>
+ if (e.getCause() != null) findCause(e.getCause()) else e
+ case e: InvocationTargetException =>
+ if (e.getCause() != null) findCause(e.getCause()) else e
+ case e: Throwable =>
+ e
+ }
+
try {
mainMethod.invoke(null, childArgs.toArray)
} catch {
- case e: InvocationTargetException => e.getCause match {
- case cause: Throwable => throw cause
- case null => throw e
- }
+ case t: Throwable =>
+ throw findCause(t)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index bd0ae26fd8210..fa38070c6fcfe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -57,6 +57,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var pyFiles: String = null
var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
+ var proxyUser: String = null
// Standalone cluster mode only
var supervise: Boolean = false
@@ -405,6 +406,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
parse(tail)
+ case ("--proxy-user") :: value :: tail =>
+ proxyUser = value
+ parse(tail)
+
case ("--help" | "-h") :: tail =>
printUsageAndExit(0)
@@ -476,6 +481,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
|
| --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G).
|
+ | --proxy-user NAME User to impersonate when submitting the application.
+ |
| --help, -h Show this help message and exit
| --verbose, -v Print additional debug output
|
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 61d287ca9c3ac..6af8dd555f2aa 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -38,6 +38,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.log4j.PropertyConfigurator
import org.eclipse.jetty.util.MultiException
import org.json4s._
@@ -1986,6 +1987,16 @@ private[spark] object Utils extends Logging {
throw new SparkException("Invalid master URL: " + sparkUrl, e)
}
}
+
+ /**
+ * Returns the current user name. This is the currently logged in user, unless that's been
+ * overridden by the `SPARK_USER` environment variable.
+ */
+ def getCurrentUserName(): String = {
+ Option(System.getenv("SPARK_USER"))
+ .getOrElse(UserGroupInformation.getCurrentUser().getUserName())
+ }
+
}
/**
From aaf50d05c7616e4f8f16654b642500ae06cdd774 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Tue, 10 Feb 2015 17:29:52 -0800
Subject: [PATCH 053/817] [SPARK-5658][SQL] Finalize DDL and write support APIs
https://issues.apache.org/jira/browse/SPARK-5658
Author: Yin Huai
This patch had conflicts when merged, resolved by
Committer: Michael Armbrust
Closes #4446 from yhuai/writeSupportFollowup and squashes the following commits:
f3a96f7 [Yin Huai] davies's comments.
225ff71 [Yin Huai] Use Scala TestHiveContext to initialize the Python HiveContext in Python tests.
2306f93 [Yin Huai] Style.
2091fcd [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
537e28f [Yin Huai] Correctly clean up temp data.
ae4649e [Yin Huai] Fix Python test.
609129c [Yin Huai] Doc format.
92b6659 [Yin Huai] Python doc and other minor updates.
cbc717f [Yin Huai] Rename dataSourceName to source.
d1c12d3 [Yin Huai] No need to delete the duplicate rule since it has been removed in master.
22cfa70 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
d91ecb8 [Yin Huai] Fix test.
4c76d78 [Yin Huai] Simplify APIs.
3abc215 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
0832ce4 [Yin Huai] Fix test.
98e7cdb [Yin Huai] Python style.
2bf44ef [Yin Huai] Python APIs.
c204967 [Yin Huai] Format
a10223d [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
9ff97d8 [Yin Huai] Add SaveMode to saveAsTable.
9b6e570 [Yin Huai] Update doc.
c2be775 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
99950a2 [Yin Huai] Use Java enum for SaveMode.
4679665 [Yin Huai] Remove duplicate rule.
77d89dc [Yin Huai] Update doc.
e04d908 [Yin Huai] Move import and add (Scala-specific) to scala APIs.
cf5703d [Yin Huai] Add checkAnswer to Java tests.
7db95ff [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
6dfd386 [Yin Huai] Add java test.
f2f33ef [Yin Huai] Fix test.
e702386 [Yin Huai] Apache header.
b1e9b1b [Yin Huai] Format.
ed4e1b4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
af9e9b3 [Yin Huai] DDL and write support API followup.
2a6213a [Yin Huai] Update API names.
e6a0b77 [Yin Huai] Update test.
43bae01 [Yin Huai] Remove createTable from HiveContext.
5ffc372 [Yin Huai] Add more load APIs to SQLContext.
5390743 [Yin Huai] Add more save APIs to DataFrame.
---
python/pyspark/sql/context.py | 68 ++++++++
python/pyspark/sql/dataframe.py | 72 +++++++-
python/pyspark/sql/tests.py | 107 +++++++++++-
.../apache/spark/sql/sources/SaveMode.java | 45 +++++
.../org/apache/spark/sql/DataFrame.scala | 160 ++++++++++++++---
.../org/apache/spark/sql/DataFrameImpl.scala | 61 ++-----
.../apache/spark/sql/IncomputableColumn.scala | 27 +--
.../scala/org/apache/spark/sql/SQLConf.scala | 2 +-
.../org/apache/spark/sql/SQLContext.scala | 164 +++++++++++++++++-
.../spark/sql/execution/SparkStrategies.scala | 14 +-
.../apache/spark/sql/json/JSONRelation.scala | 30 +++-
.../apache/spark/sql/parquet/newParquet.scala | 45 ++++-
.../org/apache/spark/sql/sources/ddl.scala | 40 ++++-
.../apache/spark/sql/sources/interfaces.scala | 19 ++
.../spark/sql/sources/JavaSaveLoadSuite.java | 97 +++++++++++
.../org/apache/spark/sql/QueryTest.scala | 92 ++++++----
.../sources/CreateTableAsSelectSuite.scala | 29 +++-
.../spark/sql/sources/SaveLoadSuite.scala | 59 +++++--
.../apache/spark/sql/hive/HiveContext.scala | 76 --------
.../spark/sql/hive/HiveStrategies.scala | 13 +-
.../spark/sql/hive/execution/commands.scala | 105 ++++++++---
.../spark/sql/hive/{ => test}/TestHive.scala | 20 +--
.../hive/JavaMetastoreDataSourcesSuite.java | 147 ++++++++++++++++
.../org/apache/spark/sql/QueryTest.scala | 64 +++++--
.../sql/hive/InsertIntoHiveTableSuite.scala | 33 ++--
.../sql/hive/MetastoreDataSourcesSuite.scala | 118 +++++++++++--
26 files changed, 1357 insertions(+), 350 deletions(-)
create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java
create mode 100644 sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{ => test}/TestHive.scala (99%)
create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
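For orientation, a short sketch of how the finalized Scala APIs added below fit together
(assuming an existing `sqlContext` and a DataFrame `df`; the paths and table name are made up,
and saveAsTable requires a HiveContext-backed catalog):

    import org.apache.spark.sql.sources.SaveMode

    // Load using the default data source (spark.sql.sources.default) or an explicit one.
    val people = sqlContext.load("/tmp/people.parquet")
    val logs = sqlContext.load("/tmp/logs.json", "org.apache.spark.sql.json")

    // Save with an explicit source and SaveMode; SaveMode.ErrorIfExists is the default.
    df.save("/tmp/out.json", "org.apache.spark.sql.json", SaveMode.Overwrite)

    // Persist through the catalog, appending if the table already exists.
    df.saveAsTable("events", "org.apache.spark.sql.json", SaveMode.Append,
      Map("path" -> "/tmp/events"))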
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 49f016a9cf2e9..882c0f98ea40b 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -21,6 +21,7 @@
from itertools import imap
from py4j.protocol import Py4JError
+from py4j.java_collections import MapConverter
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -87,6 +88,18 @@ def _ssql_ctx(self):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext
+ def setConf(self, key, value):
+ """Sets the given Spark SQL configuration property.
+ """
+ self._ssql_ctx.setConf(key, value)
+
+ def getConf(self, key, defaultValue):
+ """Returns the value of Spark SQL configuration property for the given key.
+
+ If the key is not set, returns defaultValue.
+ """
+ return self._ssql_ctx.getConf(key, defaultValue)
+
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
@@ -455,6 +468,61 @@ def func(iterator):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)
+ def load(self, path=None, source=None, schema=None, **options):
+ """Returns the dataset in a data source as a DataFrame.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.load(source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.load(source, scala_datatype, joptions)
+ return DataFrame(df, self)
+
+ def createExternalTable(self, tableName, path=None, source=None,
+ schema=None, **options):
+ """Creates an external table based on the dataset in a data source.
+
+ It returns the DataFrame associated with the external table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame and
+ created external table.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
+ joptions)
+ return DataFrame(df, self)
+
def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 04be65fe241c4..3eef0cc376a2d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -146,9 +146,75 @@ def insertInto(self, tableName, overwrite=False):
"""
self._jdf.insertInto(tableName, overwrite)
- def saveAsTable(self, tableName):
- """Creates a new table with the contents of this DataFrame."""
- self._jdf.saveAsTable(tableName)
+ def _java_save_mode(self, mode):
+ """Returns the Java save mode based on the Python save mode represented by a string.
+ """
+ jSaveMode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode
+ jmode = jSaveMode.ErrorIfExists
+ mode = mode.lower()
+ if mode == "append":
+ jmode = jSaveMode.Append
+ elif mode == "overwrite":
+ jmode = jSaveMode.Overwrite
+ elif mode == "ignore":
+ jmode = jSaveMode.Ignore
+ elif mode == "error":
+ pass
+ else:
+ raise ValueError(
+ "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
+ return jmode
+
+ def saveAsTable(self, tableName, source=None, mode="append", **options):
+ """Saves the contents of the DataFrame to a data source as a table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the saveAsTable operation when
+ the table already exists in the data source. There are four modes:
+
+ * append: Contents of this DataFrame are expected to be appended to existing table.
+ * overwrite: Data in the existing table is expected to be overwritten by the contents of \
+ this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of the DataFrame and \
+ to not change the existing table.
+ """
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self.sql_ctx._sc._gateway._gateway_client)
+ self._jdf.saveAsTable(tableName, source, jmode, joptions)
+
+ def save(self, path=None, source=None, mode="append", **options):
+ """Saves the contents of the DataFrame to a data source.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the save operation when
+ data already exists in the data source. There are four modes:
+
+ * append: Contents of this DataFrame are expected to be appended to existing data.
+ * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of the DataFrame and \
+ to not change the existing data.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ self._jdf.save(source, jmode, joptions)
def schema(self):
"""Returns the schema of this DataFrame (represented by
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d25c6365ed067..bc945091f7042 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -34,10 +34,9 @@
else:
import unittest
-
-from pyspark.sql import SQLContext, Column
+from pyspark.sql import SQLContext, HiveContext, Column
from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType, LongType
+ UserDefinedType, DoubleType, LongType, StringType
from pyspark.tests import ReusedPySparkTestCase
@@ -286,6 +285,37 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+ def test_save_and_load(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.save(tmpPath, "org.apache.spark.sql.json", "error")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+
+ df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
+ noUse="this options will not be used in save.")
+ actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
+ noUse="this options will not be used in load.")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.sqlCtx.load(path=tmpPath)
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -296,5 +326,76 @@ def test_help_command(self):
pydoc.render_doc(df.take(1))
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ print "type", type(cls.sc)
+ print "type", type(cls.sc._jsc)
+ _scala_HiveContext =\
+ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
+ cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = cls.sc.parallelize(cls.testData)
+ cls.df = cls.sqlCtx.inferSchema(rdd)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+ def test_save_and_load_table(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
+ "org.apache.spark.sql.json")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.createExternalTable("externalJsonTable",
+ source="org.apache.spark.sql.json",
+ schema=schema, path=tmpPath,
+ noUse="this options will not be used")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.select("value").collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
if __name__ == "__main__":
unittest.main()
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java
new file mode 100644
index 0000000000000..3109f5716da2c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources;
+
+/**
+ * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
+ */
+public enum SaveMode {
+ /**
+ * Append mode means that when saving a DataFrame to a data source, if data/table already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ */
+ Append,
+ /**
+ * Overwrite mode means that when saving a DataFrame to a data source,
+ * if data/table already exists, existing data is expected to be overwritten by the contents of
+ * the DataFrame.
+ */
+ Overwrite,
+ /**
+ * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists,
+ * an exception is expected to be thrown.
+ */
+ ErrorIfExists,
+ /**
+ * Ignore mode means that when saving a DataFrame to a data source, if data already exists,
+ * the save operation is expected to not save the contents of the DataFrame and to not
+ * change the existing data.
+ */
+ Ignore
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 04e0d09947492..ca8d552c5febf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -17,19 +17,19 @@
package org.apache.spark.sql
+import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import scala.util.control.NonFatal
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.sources.SaveMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
-import scala.util.control.NonFatal
-
-
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
new DataFrameImpl(sqlContext, logicalPlan)
@@ -574,8 +574,64 @@ trait DataFrame extends RDDApi[Row] {
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame. This will fail if the table already
- * exists.
+ * Creates a table from the contents of this DataFrame.
+ * It will use the default data source configured by spark.sql.sources.default.
+ * This will fail if the table already exists.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ def saveAsTable(tableName: String): Unit = {
+ saveAsTable(tableName, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table from the contents of this DataFrame, using the default data source
+ * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ def saveAsTable(tableName: String, mode: SaveMode): Unit = {
+ if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) {
+ // If table already exists and the save mode is Append,
+ // we will just call insertInto to append the contents of this DataFrame.
+ insertInto(tableName, overwrite = false)
+ } else {
+ val dataSourceName = sqlContext.conf.defaultDataSourceName
+ saveAsTable(tableName, dataSourceName, mode)
+ }
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table at the given path from the contents of this DataFrame
+ * based on a given data source and a set of options,
+ * using [[SaveMode.ErrorIfExists]] as the save mode.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ def saveAsTable(
+ tableName: String,
+ source: String): Unit = {
+ saveAsTable(tableName, source, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table at the given path from the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, and a set of options.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
@@ -583,12 +639,17 @@ trait DataFrame extends RDDApi[Row] {
* be the target of an `insertInto`.
*/
@Experimental
- def saveAsTable(tableName: String): Unit
+ def saveAsTable(
+ tableName: String,
+ source: String,
+ mode: SaveMode): Unit = {
+ saveAsTable(tableName, source, mode, Map.empty[String, String])
+ }
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame based on a given data source and
- * a set of options. This will fail if the table already exists.
+ * Creates a table at the given path from the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, and a set of options.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
@@ -598,14 +659,17 @@ trait DataFrame extends RDDApi[Row] {
@Experimental
def saveAsTable(
tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String]): Unit = {
+ saveAsTable(tableName, source, mode, options.toMap)
+ }
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame based on a given data source and
- * a set of options. This will fail if the table already exists.
+ * (Scala-specific)
+ * Creates a table from the contents of this DataFrame based on a given data source,
+ * [[SaveMode]] specified by mode, and a set of options.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
@@ -615,22 +679,76 @@ trait DataFrame extends RDDApi[Row] {
@Experimental
def saveAsTable(
tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path,
+ * using the default data source configured by spark.sql.sources.default and
+ * [[SaveMode.ErrorIfExists]] as the save mode.
+ */
+ @Experimental
+ def save(path: String): Unit = {
+ save(path, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode,
+ * using the default data source configured by spark.sql.sources.default.
+ */
+ @Experimental
+ def save(path: String, mode: SaveMode): Unit = {
+ val dataSourceName = sqlContext.conf.defaultDataSourceName
+ save(path, dataSourceName, mode)
+ }
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source,
+ * using [[SaveMode.ErrorIfExists]] as the save mode.
+ */
+ @Experimental
+ def save(path: String, source: String): Unit = {
+ save(source, SaveMode.ErrorIfExists, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source and
+ * [[SaveMode]] specified by mode.
+ */
@Experimental
- def save(path: String): Unit
+ def save(path: String, source: String, mode: SaveMode): Unit = {
+ save(source, mode, Map("path" -> path))
+ }
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame based on the given data source,
+ * [[SaveMode]] specified by mode, and a set of options.
+ */
@Experimental
def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String]): Unit = {
+ save(source, mode, options.toMap)
+ }
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Saves the contents of this DataFrame based on the given data source,
+ * [[SaveMode]] specified by mode, and a set of options
+ */
@Experimental
def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit
/**
* :: Experimental ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 1ee16ad5161c8..11f9334556981 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -28,13 +28,14 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan}
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{NumericType, StructType}
@@ -341,68 +342,34 @@ private[sql] class DataFrameImpl protected[sql](
override def saveAsParquetFile(path: String): Unit = {
if (sqlContext.conf.parquetUseDataSourceApi) {
- save("org.apache.spark.sql.parquet", "path" -> path)
+ save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path))
} else {
sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
}
}
- override def saveAsTable(tableName: String): Unit = {
- val dataSourceName = sqlContext.conf.defaultDataSourceName
- val cmd =
- CreateTableUsingAsLogicalPlan(
- tableName,
- dataSourceName,
- temporary = false,
- Map.empty,
- allowExisting = false,
- logicalPlan)
-
- sqlContext.executePlan(cmd).toRdd
- }
-
override def saveAsTable(
tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = {
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = {
val cmd =
CreateTableUsingAsLogicalPlan(
tableName,
- dataSourceName,
+ source,
temporary = false,
- (option +: options).toMap,
- allowExisting = false,
+ mode,
+ options,
logicalPlan)
sqlContext.executePlan(cmd).toRdd
}
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- saveAsTable(tableName, dataSourceName, opts.head, opts.tail:_*)
- }
-
- override def save(path: String): Unit = {
- val dataSourceName = sqlContext.conf.defaultDataSourceName
- save(dataSourceName, "path" -> path)
- }
-
- override def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = {
- ResolvedDataSource(sqlContext, dataSourceName, (option +: options).toMap, this)
- }
-
override def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- save(dataSourceName, opts.head, opts.tail:_*)
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = {
+ ResolvedDataSource(sqlContext, source, mode, options, this)
}
override def insertInto(tableName: String, overwrite: Boolean): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index ce0557b88196f..494e49c1317b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedSt
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.sources.SaveMode
import org.apache.spark.sql.types.StructType
-
private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column {
def this(name: String) = this(name match {
@@ -156,29 +156,16 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def saveAsParquetFile(path: String): Unit = err()
- override def saveAsTable(tableName: String): Unit = err()
-
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = err()
-
override def saveAsTable(
tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = err()
-
- override def save(path: String): Unit = err()
-
- override def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = err()
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = err()
override def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = err()
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = err()
override def insertInto(tableName: String, overwrite: Boolean): Unit = err()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 180f5e765fb91..39f6c2f4bc8b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -50,7 +50,7 @@ private[spark] object SQLConf {
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
// This is used to set the default data source
- val DEFAULT_DATA_SOURCE_NAME = "spark.sql.default.datasource"
+ val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
// Whether to perform eager analysis on a DataFrame.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 97e3777f933e4..801505bceb956 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -401,27 +401,173 @@ class SQLContext(@transient val sparkContext: SparkContext)
jsonRDD(json.rdd, samplingRatio);
}
+ /**
+ * :: Experimental ::
+ * Returns the dataset stored at path as a DataFrame,
+ * using the default data source configured by spark.sql.sources.default.
+ */
@Experimental
def load(path: String): DataFrame = {
val dataSourceName = conf.defaultDataSourceName
- load(dataSourceName, ("path", path))
+ load(path, dataSourceName)
}
+ /**
+ * :: Experimental ::
+ * Returns the dataset stored at path as a DataFrame,
+ * using the given data source.
+ */
@Experimental
- def load(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): DataFrame = {
- val resolved = ResolvedDataSource(this, None, dataSourceName, (option +: options).toMap)
+ def load(path: String, source: String): DataFrame = {
+ load(source, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame.
+ */
+ @Experimental
+ def load(source: String, options: java.util.Map[String, String]): DataFrame = {
+ load(source, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame.
+ */
+ @Experimental
+ def load(source: String, options: Map[String, String]): DataFrame = {
+ val resolved = ResolvedDataSource(this, None, source, options)
DataFrame(this, LogicalRelation(resolved.relation))
}
+ /**
+ * :: Experimental ::
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame,
+ * using the given schema as the schema of the DataFrame.
+ */
@Experimental
def load(
- dataSourceName: String,
+ source: String,
+ schema: StructType,
options: java.util.Map[String, String]): DataFrame = {
- val opts = options.toSeq
- load(dataSourceName, opts.head, opts.tail:_*)
+ load(source, schema, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Returns the dataset specified by the given data source and a set of options as a DataFrame,
+ * using the given schema as the schema of the DataFrame.
+ */
+ @Experimental
+ def load(
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val resolved = ResolvedDataSource(this, Some(schema), source, options)
+ DataFrame(this, LogicalRelation(resolved.relation))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path and returns the corresponding DataFrame.
+ * It will use the default data source configured by spark.sql.sources.default.
+ */
+ @Experimental
+ def createExternalTable(tableName: String, path: String): DataFrame = {
+ val dataSourceName = conf.defaultDataSourceName
+ createExternalTable(tableName, path, dataSourceName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source
+ * and returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ path: String,
+ source: String): DataFrame = {
+ createExternalTable(tableName, source, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: java.util.Map[String, String]): DataFrame = {
+ createExternalTable(tableName, source, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: Map[String, String]): DataFrame = {
+ val cmd =
+ CreateTableUsing(
+ tableName,
+ userSpecifiedSchema = None,
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: java.util.Map[String, String]): DataFrame = {
+ createExternalTable(tableName, source, schema, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Creates an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val cmd =
+ CreateTableUsing(
+ tableName,
+ userSpecifiedSchema = Some(schema),
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableName)
}
/**
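A hedged usage sketch of the load and createExternalTable overloads added above. The path, table name, and SQLContext instance are illustrative; the load variants work on any SQLContext, while createExternalTable needs a catalog that can persist tables (for example a HiveContext).

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    object LoadExamples {
      def run(sqlContext: SQLContext): Unit = {
        // Uses the data source configured by spark.sql.sources.default.
        val df1 = sqlContext.load("/tmp/users.json")
        // Names the data source explicitly.
        val df2 = sqlContext.load("/tmp/users.json", "org.apache.spark.sql.json")
        // Passes an options map and a user-specified schema.
        val schema = StructType(StructField("b", StringType, nullable = true) :: Nil)
        val df3 = sqlContext.load(
          "org.apache.spark.sql.json", schema, Map("path" -> "/tmp/users.json"))
        // Registers a persistent table backed by the data at the given path.
        val users = sqlContext.createExternalTable(
          "users", "/tmp/users.json", "org.apache.spark.sql.json")
      }
    }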
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index edf8a5be64ff1..e915e0e6a0ec1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -309,7 +309,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object DDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false) =>
+ case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) =>
ExecutedCommand(
CreateTempTableUsing(
tableName, userSpecifiedSchema, provider, opts)) :: Nil
@@ -318,24 +318,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")
- case CreateTableUsingAsSelect(tableName, provider, true, opts, false, query) =>
+ case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) =>
val logicalPlan = sqlContext.parseSql(query)
val cmd =
- CreateTempTableUsingAsSelect(tableName, provider, opts, logicalPlan)
+ CreateTempTableUsingAsSelect(tableName, provider, mode, opts, logicalPlan)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
- case c: CreateTableUsingAsSelect if c.temporary && c.allowExisting =>
- sys.error("allowExisting should be set to false when creating a temporary table.")
- case CreateTableUsingAsLogicalPlan(tableName, provider, true, opts, false, query) =>
+ case CreateTableUsingAsLogicalPlan(tableName, provider, true, mode, opts, query) =>
val cmd =
- CreateTempTableUsingAsSelect(tableName, provider, opts, query)
+ CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsLogicalPlan if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
- case c: CreateTableUsingAsLogicalPlan if c.temporary && c.allowExisting =>
- sys.error("allowExisting should be set to false when creating a temporary table.")
case LogicalDescribeCommand(table, isExtended) =>
val resultPlan = self.sqlContext.executePlan(table).executedPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index c4e14c6c92908..f828bcdd65c9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.json
import java.io.IOException
import org.apache.hadoop.fs.Path
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -29,6 +28,10 @@ import org.apache.spark.sql.types.StructType
private[sql] class DefaultSource
extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
+ private def checkPath(parameters: Map[String, String]): String = {
+ parameters.getOrElse("path", sys.error("'path' must be specified for json data."))
+ }
+
/** Returns a new base relation with the parameters. */
override def createRelation(
sqlContext: SQLContext,
@@ -52,15 +55,30 @@ private[sql] class DefaultSource
override def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
- val path = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val path = checkPath(parameters)
val filesystemPath = new Path(path)
val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
- if (fs.exists(filesystemPath)) {
- sys.error(s"path $path already exists.")
+ val doSave = if (fs.exists(filesystemPath)) {
+ mode match {
+ case SaveMode.Append =>
+ sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
+ case SaveMode.Overwrite =>
+ fs.delete(filesystemPath, true)
+ true
+ case SaveMode.ErrorIfExists =>
+ sys.error(s"path $path already exists.")
+ case SaveMode.Ignore => false
+ }
+ } else {
+ true
+ }
+ if (doSave) {
+ // Only save data when the save mode is not ignore.
+ data.toJSON.saveAsTextFile(path)
}
- data.toJSON.saveAsTextFile(path)
createRelation(sqlContext, parameters, data.schema)
}
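The JSON source above is the first concrete SaveMode dispatch in this patch: decide whether to write based on the mode and on whether the destination path already exists. A minimal standalone sketch of that logic (the helper name is ours, not part of the patch):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.spark.sql.sources.SaveMode

    object SaveModeDispatch {
      /** Returns true when the caller should write data at `path` under `mode`. */
      def shouldWrite(mode: SaveMode, path: Path, conf: Configuration): Boolean = {
        val fs = path.getFileSystem(conf)
        if (!fs.exists(path)) {
          true
        } else {
          mode match {
            case SaveMode.Append =>
              sys.error("Append mode is not supported by this source")
            case SaveMode.Overwrite =>
              fs.delete(path, true) // drop the old data, then write
              true
            case SaveMode.ErrorIfExists =>
              sys.error(s"path $path already exists.")
            case SaveMode.Ignore =>
              false // leave the existing data untouched
          }
        }
      }
    }

The Parquet source below follows the same pattern.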
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 04804f78f5c34..aef9c10fbcd01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -80,18 +80,45 @@ class DefaultSource
override def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
val path = checkPath(parameters)
- ParquetRelation.createEmpty(
- path,
- data.schema.toAttributes,
- false,
- sqlContext.sparkContext.hadoopConfiguration,
- sqlContext)
-
- val relation = createRelation(sqlContext, parameters, data.schema)
- relation.asInstanceOf[ParquetRelation2].insert(data, true)
+ val filesystemPath = new Path(path)
+ val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ val doSave = if (fs.exists(filesystemPath)) {
+ mode match {
+ case SaveMode.Append =>
+ sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
+ case SaveMode.Overwrite =>
+ fs.delete(filesystemPath, true)
+ true
+ case SaveMode.ErrorIfExists =>
+ sys.error(s"path $path already exists.")
+ case SaveMode.Ignore => false
+ }
+ } else {
+ true
+ }
+
+ val relation = if (doSave) {
+ // Only save data when the save mode is not ignore.
+ ParquetRelation.createEmpty(
+ path,
+ data.schema.toAttributes,
+ false,
+ sqlContext.sparkContext.hadoopConfiguration,
+ sqlContext)
+
+ val createdRelation = createRelation(sqlContext, parameters, data.schema)
+ createdRelation.asInstanceOf[ParquetRelation2].insert(data, true)
+
+ createdRelation
+ } else {
+ // If the save mode is Ignore, we will just create the relation based on existing data.
+ createRelation(sqlContext, parameters)
+ }
+
relation
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 9f64f761002c9..6487c14b1eb8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -119,11 +119,20 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
throw new DDLException(
"a CREATE TABLE AS SELECT statement does not allow column definitions.")
}
+ // When the IF NOT EXISTS clause appears in the query, the save mode will be Ignore.
+ val mode = if (allowExisting.isDefined) {
+ SaveMode.Ignore
+ } else if (temp.isDefined) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+
CreateTableUsingAsSelect(tableName,
provider,
temp.isDefined,
+ mode,
options,
- allowExisting.isDefined,
query.get)
} else {
val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
@@ -133,7 +142,8 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
provider,
temp.isDefined,
options,
- allowExisting.isDefined)
+ allowExisting.isDefined,
+ managedIfNoPath = false)
}
}
)
@@ -264,6 +274,7 @@ object ResolvedDataSource {
def apply(
sqlContext: SQLContext,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
val loader = Utils.getContextOrSparkClassLoader
@@ -277,7 +288,7 @@ object ResolvedDataSource {
val relation = clazz.newInstance match {
case dataSource: CreatableRelationProvider =>
- dataSource.createRelation(sqlContext, options, data)
+ dataSource.createRelation(sqlContext, mode, options, data)
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
}
@@ -307,28 +318,40 @@ private[sql] case class DescribeCommand(
new MetadataBuilder().putString("comment", "comment of the column").build())())
}
+/**
+ * Used to represent the operation of creating a table using a data source.
+ * @param tableName the name of the table to create
+ * @param userSpecifiedSchema the schema explicitly specified by the user, if any
+ * @param provider the fully qualified name of the data source
+ * @param temporary whether the table is temporary
+ * @param options the options passed to the data source
+ * @param allowExisting If it is true, we will do nothing when the table already exists.
+ * If it is false, an exception will be thrown.
+ * @param managedIfNoPath If it is true and no path is specified in options, the table is
+ * treated as managed and its data is placed at the default table location.
+ */
private[sql] case class CreateTableUsing(
tableName: String,
userSpecifiedSchema: Option[StructType],
provider: String,
temporary: Boolean,
options: Map[String, String],
- allowExisting: Boolean) extends Command
+ allowExisting: Boolean,
+ managedIfNoPath: Boolean) extends Command
private[sql] case class CreateTableUsingAsSelect(
tableName: String,
provider: String,
temporary: Boolean,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
query: String) extends Command
private[sql] case class CreateTableUsingAsLogicalPlan(
tableName: String,
provider: String,
temporary: Boolean,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
query: LogicalPlan) extends Command
private [sql] case class CreateTempTableUsing(
@@ -348,12 +371,13 @@ private [sql] case class CreateTempTableUsing(
private [sql] case class CreateTempTableUsingAsSelect(
tableName: String,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
def run(sqlContext: SQLContext) = {
val df = DataFrame(sqlContext, query)
- val resolved = ResolvedDataSource(sqlContext, provider, options, df)
+ val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df)
sqlContext.registerRDDAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
@@ -364,7 +388,7 @@ private [sql] case class CreateTempTableUsingAsSelect(
/**
* Builds a map in which keys are case insensitive
*/
-protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
+protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
with Serializable {
val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))
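To restate the parser change above: the IF NOT EXISTS clause and the TEMPORARY keyword together determine the SaveMode attached to a CTAS plan. A compact, illustrative summary (TEMPORARY combined with IF NOT EXISTS is rejected earlier by the parser, as the CreateTableAsSelectSuite change below shows):

    import org.apache.spark.sql.sources.SaveMode

    object CtasModes {
      def ctasSaveMode(ifNotExists: Boolean, isTemporary: Boolean): SaveMode =
        if (ifNotExists) SaveMode.Ignore           // CREATE TABLE IF NOT EXISTS ... AS SELECT
        else if (isTemporary) SaveMode.Overwrite   // CREATE TEMPORARY TABLE ... AS SELECT
        else SaveMode.ErrorIfExists                // CREATE TABLE ... AS SELECT
    }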
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 5eecc303ef72b..37fda7ba6e5d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -79,8 +79,27 @@ trait SchemaRelationProvider {
@DeveloperApi
trait CreatableRelationProvider {
+ /**
+ * Creates a relation with the given parameters based on the contents of the given
+ * DataFrame. The mode specifies the expected behavior of createRelation when
+ * data already exists.
+ * Right now, there are four modes: Append, Overwrite, ErrorIfExists, and Ignore.
+ * Append mode means that when saving a DataFrame to a data source, if data already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ * Overwrite mode means that when saving a DataFrame to a data source, if data already exists,
+ * existing data is expected to be overwritten by the contents of the DataFrame.
+ * ErrorIfExists mode means that when saving a DataFrame to a data source,
+ * if data already exists, an exception is expected to be thrown.
+ * Ignore mode means that when saving a DataFrame to a data source, if data already exists,
+ * the save operation is expected to not save the contents of the DataFrame and to not
+ * change the existing data.
+ *
+ * @param sqlContext the SQLContext used to create the relation
+ * @param mode the save mode that specifies the behavior when data already exists
+ * @param parameters the options passed to the data source (e.g. a path)
+ * @param data the DataFrame whose contents back the created relation
+ * @return the created BaseRelation
+ */
def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation
}
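For data source authors, a minimal, hypothetical implementation of the extended interface could look like the sketch below. ExampleSource and its exists/write helpers are placeholders for illustration, not part of the patch; a real source would also distinguish Overwrite from Append in its write path.

    import org.apache.spark.sql.{DataFrame, SQLContext}
    import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, SaveMode}
    import org.apache.spark.sql.types.StructType

    class ExampleSource extends CreatableRelationProvider {
      // Hypothetical storage helpers; a real source would talk to its backing system.
      private def exists(path: String): Boolean = false
      private def write(path: String, data: DataFrame): Unit = ()

      override def createRelation(
          sqlContext: SQLContext,
          mode: SaveMode,
          parameters: Map[String, String],
          data: DataFrame): BaseRelation = {
        val path = parameters.getOrElse("path", sys.error("'path' must be specified."))
        if (exists(path)) {
          mode match {
            case SaveMode.ErrorIfExists => sys.error(s"path $path already exists.")
            case SaveMode.Ignore => // keep the existing data as-is
            case SaveMode.Overwrite | SaveMode.Append => write(path, data)
          }
        } else {
          write(path, data)
        }
        val outer = sqlContext
        new BaseRelation {
          override val sqlContext: SQLContext = outer
          override val schema: StructType = data.schema
        }
      }
    }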
diff --git a/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
new file mode 100644
index 0000000000000..852baf0e09245
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.test.TestSQLContext$;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+
+public class JavaSaveLoadSuite {
+
+ private transient JavaSparkContext sc;
+ private transient SQLContext sqlContext;
+
+ String originalDefaultSource;
+ File path;
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List<Row> expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ sqlContext = TestSQLContext$.MODULE$;
+ sc = new JavaSparkContext(sqlContext.sparkContext());
+
+ originalDefaultSource = sqlContext.conf().defaultDataSourceName();
+ path =
+ Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
+ if (path.exists()) {
+ path.delete();
+ }
+
+ List<String> jsonObjects = new ArrayList<String>(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
+ }
+ JavaRDD<String> rdd = sc.parallelize(jsonObjects);
+ df = sqlContext.jsonRDD(rdd);
+ df.registerTempTable("jsonTable");
+ }
+
+ @Test
+ public void saveAndLoad() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options);
+
+ DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options);
+
+ checkAnswer(loadedDF, df.collectAsList());
+ }
+
+ @Test
+ public void saveAndLoadWithSchema() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options);
+
+ List<StructField> fields = new ArrayList<>();
+ fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options);
+
+ checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f9ddd2ca5c567..dfb6858957fb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
+import scala.collection.JavaConversions._
+
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
@@ -52,9 +54,51 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param rdd the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ QueryTest.checkAnswer(rdd, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+
+ /**
+ * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
+ */
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ assert(
+ cachedData.size == numCachedTables,
+ s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+}
+
+object QueryTest {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * If there was an exception during the execution, or the contents of the DataFrame do not
+ * match the expected result, an error message will be returned. Otherwise, [[None]] will
+ * be returned.
+ * @param rdd the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -70,18 +114,20 @@ class QueryTest extends PlanTest {
}
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
- fail(
+ val errorMessage =
s"""
|Exception thrown while executing query:
|${rdd.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin)
+ """.stripMargin
+ return Some(errorMessage)
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- fail(s"""
+ val errorMessage =
+ s"""
|Results do not match for query:
|${rdd.logicalPlan}
|== Analyzed Plan ==
@@ -90,37 +136,21 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
- """.stripMargin)
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
}
- }
- protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
- checkAnswer(rdd, Seq(expectedAnswer))
- }
-
- def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
- test(sqlString) {
- checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
- }
+ return None
}
- /**
- * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
- */
- def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
- val planWithCaching = query.queryExecution.withCachedData
- val cachedData = planWithCaching collect {
- case cached: InMemoryRelation => cached
+ def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = {
+ checkAnswer(rdd, expectedAnswer.toSeq) match {
+ case Some(errorMessage) => errorMessage
+ case None => null
}
-
- assert(
- cachedData.size == numCachedTables,
- s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
- planWithCaching)
}
-
}
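The refactor above moves the comparison logic into a QueryTest companion object that returns Option[String] instead of calling fail(), so the same code can be shared with Java tests (which get a null-on-success overload). A hedged sketch of how a Scala caller consumes it; QueryTest lives in the test sources, so this is only meaningful from other tests.

    import org.apache.spark.sql.{DataFrame, QueryTest, Row}

    object AnswerCheck {
      def verify(df: DataFrame, expected: Seq[Row]): Unit = {
        QueryTest.checkAnswer(df, expected) match {
          case Some(errorMessage) => throw new AssertionError(errorMessage)
          case None => // answers match; nothing to do
        }
      }
    }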
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index b02389978b625..29caed9337ff6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -77,12 +77,10 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a, b FROM jsonTable"),
sql("SELECT a, b FROM jt").collect())
- dropTempTable("jsonTable")
-
- val message = intercept[RuntimeException]{
+ val message = intercept[DDLException]{
sql(
s"""
- |CREATE TEMPORARY TABLE jsonTable
+ |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
| path '${path.toString}'
@@ -91,10 +89,25 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
""".stripMargin)
}.getMessage
assert(
- message.contains(s"path ${path.toString} already exists."),
+ message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."),
"CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.")
- // Explicitly delete it.
+ // Overwrite the temporary table.
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTable
+ |USING org.apache.spark.sql.json.DefaultSource
+ |OPTIONS (
+ | path '${path.toString}'
+ |) AS
+ |SELECT a * 4 FROM jt
+ """.stripMargin)
+ checkAnswer(
+ sql("SELECT * FROM jsonTable"),
+ sql("SELECT a * 4 FROM jt").collect())
+
+ dropTempTable("jsonTable")
+ // Explicitly delete the data.
if (path.exists()) Utils.deleteRecursively(path)
sql(
@@ -104,12 +117,12 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
|OPTIONS (
| path '${path.toString}'
|) AS
- |SELECT a * 4 FROM jt
+ |SELECT b FROM jt
""".stripMargin)
checkAnswer(
sql("SELECT * FROM jsonTable"),
- sql("SELECT a * 4 FROM jt").collect())
+ sql("SELECT b FROM jt").collect())
dropTempTable("jsonTable")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index fe2f76cc397f5..a51004567175c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -21,10 +21,10 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.util.Utils
-
import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.{SQLConf, DataFrame}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
@@ -38,42 +38,60 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {
originalDefaultSource = conf.defaultDataSourceName
- conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json")
path = util.getTempFilePath("datasource").getCanonicalFile
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
df = jsonRDD(rdd)
+ df.registerTempTable("jsonTable")
}
override def afterAll(): Unit = {
- conf.setConf("spark.sql.default.datasource", originalDefaultSource)
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
}
after {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
if (path.exists()) Utils.deleteRecursively(path)
}
def checkLoad(): Unit = {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
checkAnswer(load(path.toString), df.collect())
- checkAnswer(load("org.apache.spark.sql.json", ("path", path.toString)), df.collect())
+
+ // Test if we can pick up the data source name passed in load.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect())
+ checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect())
+ val schema = StructType(StructField("b", StringType, true) :: Nil)
+ checkAnswer(
+ load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)),
+ sql("SELECT b FROM jsonTable").collect())
}
- test("save with overwrite and load") {
+ test("save with path and load") {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
df.save(path.toString)
- checkLoad
+ checkLoad()
+ }
+
+ test("save with path and datasource, and load") {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.save(path.toString, "org.apache.spark.sql.json")
+ checkLoad()
}
test("save with data source and options, and load") {
- df.save("org.apache.spark.sql.json", ("path", path.toString))
- checkLoad
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString))
+ checkLoad()
}
test("save and save again") {
- df.save(path.toString)
+ df.save(path.toString, "org.apache.spark.sql.json")
- val message = intercept[RuntimeException] {
- df.save(path.toString)
+ var message = intercept[RuntimeException] {
+ df.save(path.toString, "org.apache.spark.sql.json")
}.getMessage
assert(
@@ -82,7 +100,18 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
if (path.exists()) Utils.deleteRecursively(path)
- df.save(path.toString)
- checkLoad
+ df.save(path.toString, "org.apache.spark.sql.json")
+ checkLoad()
+
+ df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString))
+ checkLoad()
+
+ message = intercept[RuntimeException] {
+ df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString))
+ }.getMessage
+
+ assert(
+ message.contains("Append mode is not supported"),
+ "We should complain that 'Append mode is not supported' for JSON source.")
}
}
\ No newline at end of file
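The suite above exercises the DataFrame.save overloads this patch targets. A hedged sketch of the call shapes, with an assumed DataFrame and illustrative paths:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.sources.SaveMode

    object SaveExamples {
      def run(df: DataFrame): Unit = {
        // Default data source (spark.sql.sources.default), ErrorIfExists semantics.
        df.save("/tmp/out")
        // Explicit data source for the given path.
        df.save("/tmp/out-json", "org.apache.spark.sql.json")
        // Explicit source, save mode, and options map.
        df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> "/tmp/out-json"))
      }
    }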
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2c00659496972..7ae6ed6f841bf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -79,18 +79,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
}
- /**
- * Creates a table using the schema of the given class.
- *
- * @param tableName The name of the table to create.
- * @param allowExisting When false, an exception will be thrown if the table already exists.
- * @tparam A A case class that is used to describe the schema of the table to be created.
- */
- @Deprecated
- def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) {
- catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting)
- }
-
/**
* Invalidate and refresh all the cached metadata of the given table. For performance reasons,
* Spark SQL or the external data source library it uses might cache certain metadata about a
@@ -107,70 +95,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.invalidateTable("default", tableName)
}
- @Experimental
- def createTable(tableName: String, path: String, allowExisting: Boolean): Unit = {
- val dataSourceName = conf.defaultDataSourceName
- createTable(tableName, dataSourceName, allowExisting, ("path", path))
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- allowExisting: Boolean,
- option: (String, String),
- options: (String, String)*): Unit = {
- val cmd =
- CreateTableUsing(
- tableName,
- userSpecifiedSchema = None,
- dataSourceName,
- temporary = false,
- (option +: options).toMap,
- allowExisting)
- executePlan(cmd).toRdd
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- schema: StructType,
- allowExisting: Boolean,
- option: (String, String),
- options: (String, String)*): Unit = {
- val cmd =
- CreateTableUsing(
- tableName,
- userSpecifiedSchema = Some(schema),
- dataSourceName,
- temporary = false,
- (option +: options).toMap,
- allowExisting)
- executePlan(cmd).toRdd
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- allowExisting: Boolean,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- createTable(tableName, dataSourceName, allowExisting, opts.head, opts.tail:_*)
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- schema: StructType,
- allowExisting: Boolean,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- createTable(tableName, dataSourceName, schema, allowExisting, opts.head, opts.tail:_*)
- }
-
/**
* Analyzes the given table in the current database to generate statistics, which will be
* used in query optimizations.
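With the HiveContext.createTable overloads removed above, callers move to the createExternalTable API added to SQLContext earlier in this patch. A hedged before/after sketch with an illustrative table name and path:

    import org.apache.spark.sql.hive.HiveContext

    object CreateTableMigration {
      def run(hiveContext: HiveContext): Unit = {
        // Before: hiveContext.createTable("users", "/tmp/users.json", allowExisting = false)
        // After: register an external table backed by the same data.
        hiveContext.createExternalTable("users", "/tmp/users.json", "org.apache.spark.sql.json")
      }
    }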
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 95abc363ae767..cb138be90e2e1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -216,20 +216,21 @@ private[hive] trait HiveStrategies {
object HiveDDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, opts, allowExisting) =>
+ case CreateTableUsing(
+ tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) =>
ExecutedCommand(
CreateMetastoreDataSource(
- tableName, userSpecifiedSchema, provider, opts, allowExisting)) :: Nil
+ tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil
- case CreateTableUsingAsSelect(tableName, provider, false, opts, allowExisting, query) =>
+ case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, query) =>
val logicalPlan = hiveContext.parseSql(query)
val cmd =
- CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, logicalPlan)
+ CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, logicalPlan)
ExecutedCommand(cmd) :: Nil
- case CreateTableUsingAsLogicalPlan(tableName, provider, false, opts, allowExisting, query) =>
+ case CreateTableUsingAsLogicalPlan(tableName, provider, false, mode, opts, query) =>
val cmd =
- CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, query)
+ CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, query)
ExecutedCommand(cmd) :: Nil
case _ => Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 95dcaccefdc54..f6bea1c6a6fe1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.sources.ResolvedDataSource
+import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -105,7 +107,8 @@ case class CreateMetastoreDataSource(
userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String],
- allowExisting: Boolean) extends RunnableCommand {
+ allowExisting: Boolean,
+ managedIfNoPath: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
@@ -120,7 +123,7 @@ case class CreateMetastoreDataSource(
var isExternal = true
val optionsWithPath =
- if (!options.contains("path")) {
+ if (!options.contains("path") && managedIfNoPath) {
isExternal = false
options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName))
} else {
@@ -141,22 +144,13 @@ case class CreateMetastoreDataSource(
case class CreateMetastoreDataSourceAsSelect(
tableName: String,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
query: LogicalPlan) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
-
- if (hiveContext.catalog.tableExists(tableName :: Nil)) {
- if (allowExisting) {
- return Seq.empty[Row]
- } else {
- sys.error(s"Table $tableName already exists.")
- }
- }
-
- val df = DataFrame(hiveContext, query)
+ var createMetastoreTable = false
var isExternal = true
val optionsWithPath =
if (!options.contains("path")) {
@@ -166,15 +160,82 @@ case class CreateMetastoreDataSourceAsSelect(
options
}
- // Create the relation based on the data of df.
- ResolvedDataSource(sqlContext, provider, optionsWithPath, df)
+ if (sqlContext.catalog.tableExists(Seq(tableName))) {
+ // Check if we need to throw an exception or just return.
+ mode match {
+ case SaveMode.ErrorIfExists =>
+ sys.error(s"Table $tableName already exists. " +
+ s"If you want to append into it, please set mode to SaveMode.Append. " +
+ s"Or, if you want to overwrite it, please set mode to SaveMode.Overwrite.")
+ case SaveMode.Ignore =>
+ // Since the table already exists and the save mode is Ignore, we will just return.
+ return Seq.empty[Row]
+ case SaveMode.Append =>
+ // Check if the specified data source matches the data source of the existing table.
+ val resolved =
+ ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath)
+ val createdRelation = LogicalRelation(resolved.relation)
+ EliminateAnalysisOperators(sqlContext.table(tableName).logicalPlan) match {
+ case l @ LogicalRelation(i: InsertableRelation) =>
+ if (l.schema != createdRelation.schema) {
+ val errorDescription =
+ s"Cannot append to table $tableName because the schema of this " +
+ s"DataFrame does not match the schema of table $tableName."
+ val errorMessage =
+ s"""
+ |$errorDescription
+ |== Schemas ==
+ |${sideBySide(
+ s"== Expected Schema ==" +:
+ l.schema.treeString.split("\\\n"),
+ s"== Actual Schema ==" +:
+ createdRelation.schema.treeString.split("\\\n")).mkString("\n")}
+ """.stripMargin
+ sys.error(errorMessage)
+ } else if (i != createdRelation.relation) {
+ val errorDescription =
+ s"Cannot append to table $tableName because the resolved relation does not " +
+ s"match the existing relation of $tableName. " +
+ s"You can use insertInto($tableName, false) to append this DataFrame to the " +
+ s"table $tableName and using its data source and options."
+ val errorMessage =
+ s"""
+ |$errorDescription
+ |== Relations ==
+ |${sideBySide(
+ s"== Expected Relation ==" ::
+ l.toString :: Nil,
+ s"== Actual Relation ==" ::
+ createdRelation.toString :: Nil).mkString("\n")}
+ """.stripMargin
+ sys.error(errorMessage)
+ }
+ case o =>
+ sys.error(s"Saving data in ${o.toString} is not supported.")
+ }
+ case SaveMode.Overwrite =>
+ hiveContext.sql(s"DROP TABLE IF EXISTS $tableName")
+ // Need to create the table again.
+ createMetastoreTable = true
+ }
+ } else {
+ // The table does not exist. We need to create it in metastore.
+ createMetastoreTable = true
+ }
- hiveContext.catalog.createDataSourceTable(
- tableName,
- None,
- provider,
- optionsWithPath,
- isExternal)
+ val df = DataFrame(hiveContext, query)
+
+ // Create the relation based on the data of df.
+ ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df)
+
+ if (createMetastoreTable) {
+ hiveContext.catalog.createDataSourceTable(
+ tableName,
+ Some(df.schema),
+ provider,
+ optionsWithPath,
+ isExternal)
+ }
Seq.empty[Row]
}
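A condensed restatement of the CTAS-against-metastore decision encoded above. The ADT and function names are illustrative; the real command additionally verifies that an Append target has a matching schema and relation before writing.

    import org.apache.spark.sql.sources.SaveMode

    object CtasDecision {
      sealed trait CtasAction
      case object FailTableExists extends CtasAction
      case object ReturnNoOp extends CtasAction
      case object AppendToExisting extends CtasAction     // after schema/relation checks
      case object CreateTableAndWrite extends CtasAction  // create the metastore entry, then write

      def actionFor(tableExists: Boolean, mode: SaveMode): CtasAction =
        if (!tableExists) {
          CreateTableAndWrite
        } else {
          mode match {
            case SaveMode.ErrorIfExists => FailTableExists
            case SaveMode.Ignore => ReturnNoOp
            case SaveMode.Append => AppendToExisting
            case SaveMode.Overwrite => CreateTableAndWrite // DROP TABLE IF EXISTS first
          }
        }
    }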
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
similarity index 99%
rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 7c1d1133c3425..840fbc197259a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -20,9 +20,6 @@ package org.apache.spark.sql.hive.test
import java.io.File
import java.util.{Set => JavaSet}
-import scala.collection.mutable
-import scala.language.implicitConversions
-
import org.apache.hadoop.hive.ql.exec.FunctionRegistry
import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat}
import org.apache.hadoop.hive.ql.metadata.Table
@@ -30,16 +27,18 @@ import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.serde2.RegexSerDe
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.hive.serde2.avro.AvroSerDe
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.execution.HiveNativeCommand
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkContext}
+
+import scala.collection.mutable
+import scala.language.implicitConversions
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -224,11 +223,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
}
}),
TestTable("src_thrift", () => {
- import org.apache.thrift.protocol.TBinaryProtocol
- import org.apache.hadoop.hive.serde2.thrift.test.Complex
import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer
- import org.apache.hadoop.mapred.SequenceFileInputFormat
- import org.apache.hadoop.mapred.SequenceFileOutputFormat
+ import org.apache.hadoop.hive.serde2.thrift.test.Complex
+ import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat}
+ import org.apache.thrift.protocol.TBinaryProtocol
val srcThrift = new Table("default", "src_thrift")
srcThrift.setFields(Nil)
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
new file mode 100644
index 0000000000000..9744a2aa3f59c
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.spark.sql.sources.SaveMode;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.QueryTest$;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.hive.test.TestHive$;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+
+public class JavaMetastoreDataSourcesSuite {
+ private transient JavaSparkContext sc;
+ private transient HiveContext sqlContext;
+
+ String originalDefaultSource;
+ File path;
+ Path hiveManagedPath;
+ FileSystem fs;
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List<Row> expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ sqlContext = TestHive$.MODULE$;
+ sc = new JavaSparkContext(sqlContext.sparkContext());
+
+ originalDefaultSource = sqlContext.conf().defaultDataSourceName();
+ path =
+ Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
+ if (path.exists()) {
+ path.delete();
+ }
+ hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable"));
+ fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration());
+ if (fs.exists(hiveManagedPath)){
+ fs.delete(hiveManagedPath, true);
+ }
+
+ List<String> jsonObjects = new ArrayList<String>(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
+ }
+ JavaRDD<String> rdd = sc.parallelize(jsonObjects);
+ df = sqlContext.jsonRDD(rdd);
+ df.registerTempTable("jsonTable");
+ }
+
+ @After
+ public void tearDown() throws IOException {
+ // Clean up tables.
+ sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable");
+ sqlContext.sql("DROP TABLE IF EXISTS externalTable");
+ }
+
+ @Test
+ public void saveExternalTableAndQueryIt() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+
+ DataFrame loadedDF =
+ sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options);
+
+ checkAnswer(loadedDF, df.collectAsList());
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM externalTable"),
+ df.collectAsList());
+ }
+
+ @Test
+ public void saveExternalTableWithSchemaAndQueryIt() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("path", path.toString());
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+
+ List<StructField> fields = new ArrayList<>();
+ fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame loadedDF =
+ sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options);
+
+ checkAnswer(
+ loadedDF,
+ sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList());
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM externalTable"),
+ sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList());
+ }
+
+ @Test
+ public void saveTableAndQueryIt() {
+ Map<String, String> options = new HashMap<String, String>();
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index ba391293884bd..0270e63557963 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql
-import org.scalatest.FunSuite
+import scala.collection.JavaConversions._
-import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
@@ -55,9 +53,36 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param rdd the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ QueryTest.checkAnswer(rdd, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+}
+
+object QueryTest {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * If there was an exception during the execution, or the contents of the DataFrame do not
+ * match the expected result, an error message will be returned. Otherwise, [[None]] will
+ * be returned.
+ * @param rdd the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -73,18 +98,20 @@ class QueryTest extends PlanTest {
}
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
- fail(
+ val errorMessage =
s"""
|Exception thrown while executing query:
|${rdd.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin)
+ """.stripMargin
+ return Some(errorMessage)
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- fail(s"""
+ val errorMessage =
+ s"""
|Results do not match for query:
|${rdd.logicalPlan}
|== Analyzed Plan ==
@@ -93,22 +120,21 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
- """.stripMargin)
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
}
- }
- protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
- checkAnswer(rdd, Seq(expectedAnswer))
+ return None
}
- def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
- test(sqlString) {
- checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = {
+ checkAnswer(rdd, expectedAnswer.toSeq) match {
+ case Some(errorMessage) => errorMessage
+ case None => null
}
}
-
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 869d01eb398c5..43da7519ac8db 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -19,7 +19,11 @@ package org.apache.spark.sql.hive
import java.io.File
+import org.scalatest.BeforeAndAfter
+
import com.google.common.io.Files
+
+import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.types._
@@ -29,15 +33,22 @@ import org.apache.spark.sql.hive.test.TestHive._
case class TestData(key: Int, value: String)
-class InsertIntoHiveTableSuite extends QueryTest {
+class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
val testData = TestHive.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString)))
- testData.registerTempTable("testData")
+
+ before {
+ // Since every test here runs DDL statements,
+ // it is better to reset before every test.
+ TestHive.reset()
+ // Register the testData, which will be used in every test.
+ testData.registerTempTable("testData")
+ }
test("insertInto() HiveTable") {
- createTable[TestData]("createAndInsertTest")
+ sql("CREATE TABLE createAndInsertTest (key int, value string)")
// Add some data.
testData.insertInto("createAndInsertTest")
@@ -68,16 +79,18 @@ class InsertIntoHiveTableSuite extends QueryTest {
}
test("Double create fails when allowExisting = false") {
- createTable[TestData]("doubleCreateAndInsertTest")
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
- intercept[org.apache.hadoop.hive.ql.metadata.HiveException] {
- createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false)
- }
+ val message = intercept[QueryExecutionException] {
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ }.getMessage
+
+ println("message!!!!" + message)
}
test("Double create does not fail when allowExisting = true") {
- createTable[TestData]("createAndInsertTest")
- createTable[TestData]("createAndInsertTest")
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)")
}
test("SPARK-4052: scala.collection.Map as value type of MapType") {
@@ -98,7 +111,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
}
test("SPARK-4203:random partition directory order") {
- createTable[TestData]("tmp_table")
+ sql("CREATE TABLE tmp_table (key int, value string)")
val tmpDir = Files.createTempDir()
sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ")
sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 9ce058909f429..f94aabd29ad23 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.hive
import java.io.File
+import org.apache.spark.sql.sources.SaveMode
import org.scalatest.BeforeAndAfterEach
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.InvalidInputException
import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql._
@@ -41,11 +43,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
override def afterEach(): Unit = {
reset()
- if (ctasPath.exists()) Utils.deleteRecursively(ctasPath)
+ if (tempPath.exists()) Utils.deleteRecursively(tempPath)
}
val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile
- var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile
+ var tempPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile
test ("persistent JSON table") {
sql(
@@ -270,7 +272,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -297,7 +299,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -309,7 +311,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -325,7 +327,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE IF NOT EXISTS ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT a FROM jsonTable
""".stripMargin)
@@ -400,38 +402,122 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
sql("DROP TABLE jsonTable").collect().foreach(println)
}
- test("save and load table") {
+ test("save table") {
val originalDefaultSource = conf.defaultDataSourceName
- conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json")
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
val df = jsonRDD(rdd)
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ // Save the df as a managed table (by not specifying the path).
df.saveAsTable("savedJsonTable")
checkAnswer(
sql("SELECT * FROM savedJsonTable"),
df.collect())
- createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false)
+ // Right now, we cannot append to an existing JSON table.
+ intercept[RuntimeException] {
+ df.saveAsTable("savedJsonTable", SaveMode.Append)
+ }
+
+ // We can overwrite it.
+ df.saveAsTable("savedJsonTable", SaveMode.Overwrite)
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // When the save mode is Ignore, we do nothing if the table already exists.
+ df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore)
+ assert(df.schema === table("savedJsonTable").schema)
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // Drop table will also delete the data.
+ sql("DROP TABLE savedJsonTable")
+ intercept[InvalidInputException] {
+ jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable"))
+ }
+
+ // Create an external table by specifying the path.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.saveAsTable(
+ "savedJsonTable",
+ "org.apache.spark.sql.json",
+ SaveMode.Append,
+ Map("path" -> tempPath.toString))
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // Data should not be deleted after we drop the table.
+ sql("DROP TABLE savedJsonTable")
+ checkAnswer(
+ jsonFile(tempPath.toString),
+ df.collect())
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ }
+
+ test("create external table") {
+ val originalDefaultSource = conf.defaultDataSourceName
+
+ val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+ val df = jsonRDD(rdd)
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.saveAsTable(
+ "savedJsonTable",
+ "org.apache.spark.sql.json",
+ SaveMode.Append,
+ Map("path" -> tempPath.toString))
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ createExternalTable("createdJsonTable", tempPath.toString)
assert(table("createdJsonTable").schema === df.schema)
checkAnswer(
sql("SELECT * FROM createdJsonTable"),
df.collect())
- val message = intercept[RuntimeException] {
- createTable("createdJsonTable", filePath.toString, false)
+ var message = intercept[RuntimeException] {
+ createExternalTable("createdJsonTable", filePath.toString)
}.getMessage
assert(message.contains("Table createdJsonTable already exists."),
"We should complain that ctasJsonTable already exists")
- createTable("createdJsonTable", filePath.toString, true)
- // createdJsonTable should be not changed.
- assert(table("createdJsonTable").schema === df.schema)
+ // Data should not be deleted.
+ sql("DROP TABLE createdJsonTable")
checkAnswer(
- sql("SELECT * FROM createdJsonTable"),
+ jsonFile(tempPath.toString),
df.collect())
- conf.setConf("spark.sql.default.datasource", originalDefaultSource)
+ // Try to specify the schema.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ val schema = StructType(StructField("b", StringType, true) :: Nil)
+ createExternalTable(
+ "createdJsonTable",
+ "org.apache.spark.sql.json",
+ schema,
+ Map("path" -> tempPath.toString))
+ checkAnswer(
+ sql("SELECT * FROM createdJsonTable"),
+ sql("SELECT b FROM savedJsonTable").collect())
+
+ sql("DROP TABLE createdJsonTable")
+
+ message = intercept[RuntimeException] {
+ createExternalTable(
+ "createdJsonTable",
+ "org.apache.spark.sql.json",
+ schema,
+ Map.empty[String, String])
+ }.getMessage
+ assert(
+ message.contains("Option 'path' not specified"),
+ "We should complain that path is not specified.")
+
+ sql("DROP TABLE savedJsonTable")
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
}
}
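As a reading aid for the two tests above, here is a minimal standalone sketch (not part of the patch) of the behaviour they exercise: saving a DataFrame as a managed or external JSON table with an explicit SaveMode, and registering already-written data with createExternalTable. The `hiveCtx` value, `dataPath`, and the table names are illustrative assumptions.

    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.sources.SaveMode
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    // Sketch only: `hiveCtx` and `dataPath` are assumed to exist.
    object SaveAndRegisterSketch {
      def saveAndRegister(hiveCtx: HiveContext, dataPath: String): Unit = {
        val df = hiveCtx.jsonRDD(
          hiveCtx.sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")))

        // Managed table: no path option, so the data lives under the metastore
        // warehouse and is deleted by DROP TABLE. Overwrite replaces the contents.
        df.saveAsTable("sketchManaged")
        df.saveAsTable("sketchManaged", SaveMode.Overwrite)

        // External table: an explicit path keeps the data alive after DROP TABLE.
        df.saveAsTable(
          "sketchExternal",
          "org.apache.spark.sql.json",
          SaveMode.Append,
          Map("path" -> dataPath))

        // Register existing data, either letting the source infer the schema from
        // the path or supplying one together with the required 'path' option.
        hiveCtx.createExternalTable("sketchFromPath", dataPath)
        val schema = StructType(StructField("b", StringType, nullable = true) :: Nil)
        hiveCtx.createExternalTable(
          "sketchWithSchema",
          "org.apache.spark.sql.json",
          schema,
          Map("path" -> dataPath))
      }
    }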
From 6195e2473b98253ccc9edc3d624ba2bf59ffc398 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Tue, 10 Feb 2015 17:32:42 -0800
Subject: [PATCH 054/817] [SQL] Add an exception for analysis errors.
Also start from the bottom so we show the first error instead of the top error.
Author: Michael Armbrust
Closes #4439 from marmbrus/analysisException and squashes the following commits:
45862a0 [Michael Armbrust] fix hive test
a773bba [Michael Armbrust] Merge remote-tracking branch 'origin/master' into analysisException
f88079f [Michael Armbrust] update more cases
fede90a [Michael Armbrust] newline
fbf4bc3 [Michael Armbrust] move to sql
6235db4 [Michael Armbrust] [SQL] Add an exception for analysis errors.
---
.../apache/spark/sql/AnalysisException.scala | 23 +++++++++++++++++++
.../sql/catalyst/analysis/Analyzer.scala | 21 ++++++++++-------
.../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++++------
.../org/apache/spark/sql/SQLQuerySuite.scala | 2 +-
.../hive/execution/HiveResolutionSuite.scala | 3 ++-
5 files changed, 46 insertions(+), 17 deletions(-)
create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
new file mode 100644
index 0000000000000..871d560b9d54f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+/**
+ * Thrown when a query fails to analyze, usually because the query itself is invalid.
+ */
+class AnalysisException(message: String) extends Exception(message) with Serializable
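A hedged usage sketch (not part of the patch): analysis failures now surface as AnalysisException rather than TreeNodeException, and because CheckResolution walks the plan bottom-up, the innermost problem is the one reported. The `sqlCtx` value, the registered `people` table, and the column name below are assumptions.

    import org.apache.spark.sql.{AnalysisException, SQLContext}
    import org.scalatest.Assertions.intercept

    object AnalysisErrorSketch {
      def expectAnalysisError(sqlCtx: SQLContext): Unit = {
        val e = intercept[AnalysisException] {
          // Forcing analysis of a query with a bad column reference now throws
          // the new exception with a "cannot resolve" message for that column.
          sqlCtx.sql("SELECT no_such_column FROM people").queryExecution.analyzed
        }
        assert(e.getMessage.toLowerCase.contains("cannot resolve"))
      }
    }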
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fb2ff014cef07..3f0d77ad6322a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.util.collection.OpenHashSet
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -80,16 +81,18 @@ class Analyzer(catalog: Catalog,
*/
object CheckResolution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
- plan.transform {
+ plan.transformUp {
case p if p.expressions.exists(!_.resolved) =>
- throw new TreeNodeException(p,
- s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}")
+ val missing = p.expressions.filterNot(_.resolved).map(_.prettyString).mkString(",")
+ val from = p.inputSet.map(_.name).mkString("{", ", ", "}")
+
+ throw new AnalysisException(s"Cannot resolve '$missing' given input columns $from")
case p if !p.resolved && p.childrenResolved =>
- throw new TreeNodeException(p, "Unresolved plan found")
+ throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}")
} match {
// As a backstop, use the root node to check that the entire plan tree is resolved.
case p if !p.resolved =>
- throw new TreeNodeException(p, "Unresolved plan in tree")
+ throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}")
case p => p
}
}
@@ -314,10 +317,11 @@ class Analyzer(catalog: Catalog,
val checkField = (f: StructField) => resolver(f.name, fieldName)
val ordinal = fields.indexWhere(checkField)
if (ordinal == -1) {
- sys.error(
+ throw new AnalysisException(
s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
} else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
- sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+ throw new AnalysisException(
+ s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
} else {
ordinal
}
@@ -329,7 +333,8 @@ class Analyzer(catalog: Catalog,
case ArrayType(StructType(fields), containsNull) =>
val ordinal = findField(fields)
ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
- case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
+ case otherType =>
+ throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 60060bf02913b..f011a5ff15ea9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -69,12 +69,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
- val e = intercept[TreeNodeException[_]] {
+ val e = intercept[AnalysisException] {
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
}
- assert(e.getMessage().toLowerCase.contains("unresolved"))
+ assert(e.getMessage().toLowerCase.contains("cannot resolve"))
assert(
caseInsensitiveAnalyze(
@@ -109,10 +109,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
}
test("throw errors for unresolved attributes during analysis") {
- val e = intercept[TreeNodeException[_]] {
+ val e = intercept[AnalysisException] {
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
}
- assert(e.getMessage().toLowerCase.contains("unresolved attribute"))
+ assert(e.getMessage().toLowerCase.contains("cannot resolve"))
}
test("throw errors for unresolved plans during analysis") {
@@ -120,10 +120,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
override lazy val resolved = false
override def output = Nil
}
- val e = intercept[TreeNodeException[_]] {
+ val e = intercept[AnalysisException] {
caseSensitiveAnalyze(UnresolvedTestPlan())
}
- assert(e.getMessage().toLowerCase.contains("unresolved plan"))
+ assert(e.getMessage().toLowerCase.contains("unresolved"))
}
test("divide should be casted into fractional types") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 11502edf972e9..55fd0b0892fa1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -589,7 +589,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_)))
// Column type mismatches where a coercion is not possible, in this case between integer
// and array types, trigger an AnalysisException.
- intercept[TreeNodeException[_]] {
+ intercept[AnalysisException] {
sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect()
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index ff8130ae5f6bc..ab5f9cdddf508 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.hive.execution
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql}
import org.apache.spark.sql.hive.test.TestHive.implicits._
@@ -40,7 +41,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
"""{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested")
// there are 2 fields matching field name "b"; we should report an Ambiguous reference error
- val exception = intercept[RuntimeException] {
+ val exception = intercept[AnalysisException] {
sql("SELECT a[0].b from nested").queryExecution.analyzed
}
assert(exception.getMessage.contains("Ambiguous reference to fields"))
From a60aea86b4d4b716b5ec3bff776b509fe0831342 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Tue, 10 Feb 2015 18:19:56 -0800
Subject: [PATCH 055/817] [SPARK-5683] [SQL] Avoid multiple json generator
created
Author: Cheng Hao
Closes #4468 from chenghao-intel/json and squashes the following commits:
aeb7801 [Cheng Hao] avoid multiple json generator created
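The pattern behind the change, sketched independently of Spark (the record type and the writer function are illustrative assumptions): one Jackson generator per partition over a reusable CharArrayWriter, with the writer reset between records and the generator closed only once the input is exhausted.

    import java.io.CharArrayWriter
    import com.fasterxml.jackson.core.{JsonFactory, JsonGenerator}

    object JsonGeneratorSketch {
      // Turn an iterator of records into JSON strings with a single generator,
      // instead of creating a new factory/generator for every record.
      def toJsonStrings[T](records: Iterator[T])(write: (JsonGenerator, T) => Unit): Iterator[String] = {
        val writer = new CharArrayWriter()
        // No root value separator, so records are not joined with whitespace.
        val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)

        new Iterator[String] {
          override def hasNext: Boolean = records.hasNext
          override def next(): String = {
            write(gen, records.next())
            gen.flush()
            val json = writer.toString
            if (hasNext) writer.reset() else gen.close()
            json
          }
        }
      }

      // Example use with a trivial writer function.
      val jsons = toJsonStrings(Iterator("a", "b")) { (gen, s) => gen.writeString(s) }.toList
    }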
---
.../org/apache/spark/sql/DataFrameImpl.scala | 24 +++++++++++++++++--
.../org/apache/spark/sql/json/JsonRDD.scala | 13 +++-------
.../org/apache/spark/sql/json/JsonSuite.scala | 8 +++----
3 files changed, 29 insertions(+), 16 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 11f9334556981..0134b038f3c5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.io.CharArrayWriter
+
import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.collection.JavaConversions._
@@ -380,8 +382,26 @@ private[sql] class DataFrameImpl protected[sql](
override def toJSON: RDD[String] = {
val rowSchema = this.schema
this.mapPartitions { iter =>
- val jsonFactory = new JsonFactory()
- iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
+ val writer = new CharArrayWriter()
+ // create the generator without a separator inserted between records
+ val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+
+ new Iterator[String] {
+ override def hasNext() = iter.hasNext
+ override def next(): String = {
+ JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
+ gen.flush()
+
+ val json = writer.toString
+ if (hasNext) {
+ writer.reset()
+ } else {
+ gen.close()
+ }
+
+ json
+ }
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 33ce71b51b213..1043eefcfc6a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -23,8 +23,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
-import com.fasterxml.jackson.core.JsonProcessingException
-import com.fasterxml.jackson.core.JsonFactory
+import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException, JsonFactory}
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.rdd.RDD
@@ -430,14 +429,11 @@ private[sql] object JsonRDD extends Logging {
/** Transforms a single Row to JSON using Jackson
*
- * @param jsonFactory a JsonFactory object to construct a JsonGenerator
* @param rowSchema the schema object used for conversion
+ * @param gen a JsonGenerator object
* @param row The row to convert
*/
- private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = {
- val writer = new StringWriter()
- val gen = jsonFactory.createGenerator(writer)
-
+ private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = {
def valWriter: (DataType, Any) => Unit = {
case (_, null) | (NullType, _) => gen.writeNull()
case (StringType, v: String) => gen.writeString(v)
@@ -479,8 +475,5 @@ private[sql] object JsonRDD extends Logging {
}
valWriter(rowSchema, row)
- gen.close()
- writer.toString
}
-
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 7870cf9b0a868..4fc92e3e3b8c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -824,8 +824,8 @@ class JsonSuite extends QueryTest {
df1.registerTempTable("applySchema1")
val df2 = df1.toDataFrame
val result = df2.toJSON.collect()
- assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
- assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
+ assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
+ assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
val schema2 = StructType(
StructField("f1", StructType(
@@ -846,8 +846,8 @@ class JsonSuite extends QueryTest {
val df4 = df3.toDataFrame
val result2 = df4.toJSON.collect()
- assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
- assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
+ assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
+ assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
val jsonDF = jsonRDD(primitiveFieldAndType)
val primTable = jsonRDD(jsonDF.toJSON)
From ea60284095cad43aa7ac98256576375d0e91a52a Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 10 Feb 2015 19:40:12 -0800
Subject: [PATCH 056/817] [SPARK-5704] [SQL] [PySpark] createDataFrame from RDD
with columns
Deprecate inferSchema() and applySchema(); use createDataFrame() instead, which can take an optional `schema` to create a DataFrame from an RDD. The `schema` can be a StructType or a list of column names.
Author: Davies Liu
Closes #4498 from davies/create and squashes the following commits:
08469c1 [Davies Liu] remove Scala/Java API for now
c80a7a9 [Davies Liu] fix hive test
d1bd8f2 [Davies Liu] cleanup applySchema
9526e97 [Davies Liu] createDataFrame from RDD with columns
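A hedged Scala sketch of the rename on the Scala side (the SQLContext value and the sample data are assumptions): the StructType path that previously went through applySchema now goes through createDataFrame, with applySchema kept only as a deprecated forwarder.

    import org.apache.spark.sql._
    import org.apache.spark.sql.types._

    object CreateDataFrameSketch {
      def peopleFrame(sqlContext: SQLContext): Unit = {
        val schema = StructType(
          StructField("name", StringType, nullable = false) ::
          StructField("age", IntegerType, nullable = true) :: Nil)

        val rowRDD = sqlContext.sparkContext
          .parallelize(Seq("Alice,1", "Bob,2"))
          .map(_.split(","))
          .map(p => Row(p(0), p(1).trim.toInt))

        // Formerly sqlContext.applySchema(rowRDD, schema); applySchema now just
        // forwards here and is marked deprecated.
        val df = sqlContext.createDataFrame(rowRDD, schema)
        df.registerTempTable("people")
        sqlContext.sql("SELECT name FROM people").collect().foreach(println)
      }
    }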
---
docs/ml-guide.md | 12 +--
docs/sql-programming-guide.md | 16 ++--
.../ml/JavaCrossValidatorExample.java | 4 +-
.../examples/ml/JavaDeveloperApiExample.java | 4 +-
.../examples/ml/JavaSimpleParamsExample.java | 4 +-
.../JavaSimpleTextClassificationPipeline.java | 4 +-
.../spark/examples/sql/JavaSparkSQL.java | 2 +-
examples/src/main/python/sql.py | 4 +-
.../spark/ml/tuning/CrossValidator.scala | 4 +-
.../apache/spark/ml/JavaPipelineSuite.java | 2 +-
.../JavaLogisticRegressionSuite.java | 2 +-
.../regression/JavaLinearRegressionSuite.java | 2 +-
.../ml/tuning/JavaCrossValidatorSuite.java | 2 +-
python/pyspark/sql/context.py | 87 +++++++++++++----
python/pyspark/sql/tests.py | 26 ++---
.../org/apache/spark/sql/SQLContext.scala | 95 +++++++++++++++++--
.../spark/sql/ColumnExpressionSuite.scala | 4 +-
.../org/apache/spark/sql/SQLQuerySuite.scala | 9 +-
.../spark/sql/execution/PlannerSuite.scala | 2 +-
.../spark/sql/jdbc/JDBCWriteSuite.scala | 18 ++--
.../org/apache/spark/sql/json/JsonSuite.scala | 4 +-
.../sql/hive/InsertIntoHiveTableSuite.scala | 8 +-
.../sql/hive/execution/SQLQuerySuite.scala | 4 +-
23 files changed, 222 insertions(+), 97 deletions(-)
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index be178d7689fdd..4bf14fba34eec 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -260,7 +260,7 @@ List localTraining = Lists.newArrayList(
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
-JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+JavaSchemaRDD training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -300,7 +300,7 @@ List localTest = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
-JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+JavaSchemaRDD test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
@@ -443,7 +443,7 @@ List localTraining = Lists.newArrayList(
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -469,7 +469,7 @@ List localTest = Lists.newArrayList(
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
JavaSchemaRDD test =
- jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
model.transform(test).registerAsTable("prediction");
@@ -626,7 +626,7 @@ List localTraining = Lists.newArrayList(
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -669,7 +669,7 @@ List localTest = Lists.newArrayList(
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
-JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+JavaSchemaRDD test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test).registerAsTable("prediction");
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 38f617d0c836c..b2b007509c735 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -225,7 +225,7 @@ public static class Person implements Serializable {
{% endhighlight %}
-A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object
+A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object
for the JavaBean.
{% highlight java %}
@@ -247,7 +247,7 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m
});
// Apply a schema to an RDD of JavaBeans and register it as a table.
-JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class);
+JavaSchemaRDD schemaPeople = sqlContext.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
@@ -315,7 +315,7 @@ a `SchemaRDD` can be created programmatically with three steps.
1. Create an RDD of `Row`s from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
`Row`s in the RDD created in Step 1.
-3. Apply the schema to the RDD of `Row`s via `applySchema` method provided
+3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided
by `SQLContext`.
For example:
@@ -341,7 +341,7 @@ val schema =
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))
// Apply the schema to the RDD.
-val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema)
+val peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema)
// Register the SchemaRDD as a table.
peopleSchemaRDD.registerTempTable("people")
@@ -367,7 +367,7 @@ a `SchemaRDD` can be created programmatically with three steps.
1. Create an RDD of `Row`s from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
`Row`s in the RDD created in Step 1.
-3. Apply the schema to the RDD of `Row`s via `applySchema` method provided
+3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided
by `JavaSQLContext`.
For example:
@@ -406,7 +406,7 @@ JavaRDD rowRDD = people.map(
});
// Apply the schema to the RDD.
-JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema);
+JavaSchemaRDD peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema);
// Register the SchemaRDD as a table.
peopleSchemaRDD.registerTempTable("people");
@@ -436,7 +436,7 @@ a `SchemaRDD` can be created programmatically with three steps.
1. Create an RDD of tuples or lists from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
tuples or lists in the RDD created in the step 1.
-3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`.
+3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`.
For example:
{% highlight python %}
@@ -458,7 +458,7 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt
schema = StructType(fields)
# Apply the schema to the RDD.
-schemaPeople = sqlContext.applySchema(people, schema)
+schemaPeople = sqlContext.createDataFrame(people, schema)
# Register the SchemaRDD as a table.
schemaPeople.registerTempTable("people")
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 5041e0b6d34b0..5d8c5d0a92daa 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -71,7 +71,7 @@ public static void main(String[] args) {
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -112,7 +112,7 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test).registerTempTable("prediction");
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 4d9dad9f23038..19d0eb216848e 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -62,7 +62,7 @@ public static void main(String[] args) throws Exception {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
@@ -80,7 +80,7 @@ public static void main(String[] args) throws Exception {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
DataFrame results = model.transform(test);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index cc69e6315fdda..4c4d532388781 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -54,7 +54,7 @@ public static void main(String[] args) {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -94,7 +94,7 @@ public static void main(String[] args) {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index d929f1ad2014a..fdcfc888c235f 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -54,7 +54,7 @@ public static void main(String[] args) {
new LabeledDocument(1L, "b d", 0.0),
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -79,7 +79,7 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
model.transform(test).registerTempTable("prediction");
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index 8defb769ffaaf..dee794840a3e1 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -74,7 +74,7 @@ public Person call(String line) {
});
// Apply a schema to an RDD of Java Beans and register it as a table.
- DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class);
+ DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index 7f5c68e3d0fe2..47202fde7510b 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -31,7 +31,7 @@
Row(name="Smith", age=23),
Row(name="Sarah", age=18)])
# Infer schema from the first row, create a DataFrame and print the schema
- some_df = sqlContext.inferSchema(some_rdd)
+ some_df = sqlContext.createDataFrame(some_rdd)
some_df.printSchema()
# Another RDD is created from a list of tuples
@@ -40,7 +40,7 @@
schema = StructType([StructField("person_name", StringType(), False),
StructField("person_age", IntegerType(), False)])
# Create a DataFrame by applying the schema to the RDD and print the schema
- another_df = sqlContext.applySchema(another_rdd, schema)
+ another_df = sqlContext.createDataFrame(another_rdd, schema)
another_df.printSchema()
# root
# |-- age: integer (nullable = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5d51c51346665..324b1ba784387 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -76,8 +76,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val metrics = new Array[Double](epm.size)
val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
- val trainingDataset = sqlCtx.applySchema(training, schema).cache()
- val validationDataset = sqlCtx.applySchema(validation, schema).cache()
+ val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
+ val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 50995ffef9ad5..0a8c9e5954676 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -45,7 +45,7 @@ public void setUp() {
jsql = new SQLContext(jsc);
JavaRDD points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
- dataset = jsql.applySchema(points, LabeledPoint.class);
+ dataset = jsql.createDataFrame(points, LabeledPoint.class);
}
@After
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index d4b664479255d..3f8e59de0f05c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -50,7 +50,7 @@ public void setUp() {
jsql = new SQLContext(jsc);
List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 40d5a92bb32af..0cc36c8d56d70 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -46,7 +46,7 @@ public void setUp() {
jsql = new SQLContext(jsc);
List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 074b58c07df7a..0bb6b489f2757 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -45,7 +45,7 @@ public void setUp() {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
jsql = new SQLContext(jsc);
List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
@After
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 882c0f98ea40b..9d29ef4839a43 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -25,7 +25,7 @@
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
@@ -47,23 +47,11 @@ def __init__(self, sparkContext, sqlContext=None):
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- TypeError:...
-
- >>> bad_rdd = sc.parallelize([1,2,3])
- >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
-
>>> from datetime import datetime
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df = sqlCtx.createDataFrame(allTypes)
>>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
@@ -131,6 +119,9 @@ def registerFunction(self, name, f, returnType=StringType()):
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
+ .. note::
+ Deprecated in 1.3, use :func:`createDataFrame` instead
+
When samplingRatio is specified, the schema is inferred by looking
at the types of each row in the sampled dataset. Otherwise, the
first 100 rows of the RDD are inspected. Nested collections are
@@ -199,7 +190,7 @@ def inferSchema(self, rdd, samplingRatio=None):
warnings.warn("Some of types cannot be determined by the "
"first 100 rows, please try again with sampling")
else:
- if samplingRatio > 0.99:
+ if samplingRatio < 0.99:
rdd = rdd.sample(False, float(samplingRatio))
schema = rdd.map(_infer_schema).reduce(_merge_type)
@@ -211,6 +202,9 @@ def applySchema(self, rdd, schema):
"""
Applies the given schema to the given RDD of L{tuple} or L{list}.
+ .. note::
+ Deprecated in 1.3, use :func:`createDataFrame` instead
+
These tuples or lists can contain complex nested structures like
lists, maps or nested rows.
@@ -300,13 +294,68 @@ def applySchema(self, rdd, schema):
df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return DataFrame(df, self)
+ def createDataFrame(self, rdd, schema=None, samplingRatio=None):
+ """
+ Create a DataFrame from an RDD of tuple/list and an optional `schema`.
+
+ `schema` could be :class:`StructType` or a list of column names.
+
+ When `schema` is a list of column names, the type of each column
+ will be inferred from `rdd`.
+
+ When `schema` is None, it will try to infer the column names and types
+ from `rdd`, which should be an RDD of :class:`Row`, or namedtuple,
+ or dict.
+
+ If schema inference is needed, `samplingRatio` is used to determine how many
+ rows will be used to do the inference. The first row will be used if
+ `samplingRatio` is None.
+
+ :param rdd: an RDD of Row or tuple or list or dict
+ :param schema: a StructType or list of names of columns
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> rdd = sc.parallelize([('Alice', 1)])
+ >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
+ >>> df.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql import Row
+ >>> Person = Row('name', 'age')
+ >>> person = rdd.map(lambda r: Person(*r))
+ >>> df2 = sqlCtx.createDataFrame(person)
+ >>> df2.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql.types import *
+ >>> schema = StructType([
+ ... StructField("name", StringType(), True),
+ ... StructField("age", IntegerType(), True)])
+ >>> df3 = sqlCtx.createDataFrame(rdd, schema)
+ >>> df3.collect()
+ [Row(name=u'Alice', age=1)]
+ """
+ if isinstance(rdd, DataFrame):
+ raise TypeError("rdd is already a DataFrame")
+
+ if isinstance(schema, StructType):
+ return self.applySchema(rdd, schema)
+ else:
+ if isinstance(schema, (list, tuple)):
+ first = rdd.first()
+ if not isinstance(first, (list, tuple)):
+ raise ValueError("each row in `rdd` should be list or tuple")
+ row_cls = Row(*schema)
+ rdd = rdd.map(lambda r: row_cls(*r))
+ return self.inferSchema(rdd, samplingRatio)
+
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
"""
if (rdd.__class__ is DataFrame):
@@ -321,7 +370,6 @@ def parquetFile(self, *paths):
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlCtx.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -526,7 +574,6 @@ def createExternalTable(self, tableName, path=None, source=None,
def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
>>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
@@ -537,7 +584,6 @@ def sql(self, sqlQuery):
def table(self, tableName):
"""Returns the specified table as a L{DataFrame}.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
>>> df2 = sqlCtx.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -685,11 +731,12 @@ def _test():
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- globs['rdd'] = sc.parallelize(
+ globs['rdd'] = rdd = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ globs['df'] = sqlCtx.createDataFrame(rdd)
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index bc945091f7042..5e41e36897b5d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -96,7 +96,7 @@ def setUpClass(cls):
cls.sqlCtx = SQLContext(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.inferSchema(rdd)
+ cls.df = cls.sqlCtx.createDataFrame(rdd)
@classmethod
def tearDownClass(cls):
@@ -110,14 +110,14 @@ def test_udf(self):
def test_udf2(self):
self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
- self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+ self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
rdd = self.sc.parallelize(d)
- self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -155,17 +155,17 @@ def test_basic_functions(self):
def test_apply_schema_to_row(self):
df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema())
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ df3 = self.sqlCtx.createDataFrame(rdd, df.schema())
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
@@ -187,14 +187,14 @@ def test_infer_schema(self):
d = [Row(l=[], d={}),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
self.assertEqual([], df.map(lambda r: r.l).first())
self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
df.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
- df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ df2 = self.sqlCtx.createDataFrame(rdd, 1.0)
self.assertEqual(df.schema(), df2.schema())
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
@@ -205,7 +205,7 @@ def test_infer_schema(self):
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -214,7 +214,7 @@ def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
df.registerTempTable("test")
row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
@@ -224,7 +224,7 @@ def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
schema = df.schema()
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
@@ -238,7 +238,7 @@ def test_apply_schema_with_udt(self):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.applySchema(rdd, schema)
+ df = self.sqlCtx.createDataFrame(rdd, schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
@@ -246,7 +246,7 @@ def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- df0 = self.sqlCtx.inferSchema(rdd)
+ df0 = self.sqlCtx.createDataFrame(rdd)
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 801505bceb956..523911d108029 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* val people =
* sc.textFile("examples/src/main/resources/people.txt").map(
* _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
- * val dataFrame = sqlContext. applySchema(people, schema)
+ * val dataFrame = sqlContext.createDataFrame(people, schema)
* dataFrame.printSchema
* // root
* // |-- name: string (nullable = false)
@@ -252,11 +252,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
* dataFrame.registerTempTable("people")
* sqlContext.sql("select name from people").collect.foreach(println)
* }}}
- *
- * @group userf
*/
@DeveloperApi
- def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
@@ -264,8 +262,21 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
@DeveloperApi
- def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
- applySchema(rowRDD.rdd, schema);
+ def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD.rdd, schema)
+ }
+
+ /**
+ * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s by applying
+ * a sequence of column names to this RDD; the data type of each column will
+ * be inferred from the first row.
+ *
+ * @param rowRDD a JavaRDD of Row
+ * @param columns names for each column
+ * @return DataFrame
+ */
+ def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = {
+ createDataFrame(rowRDD.rdd, columns.toSeq)
}
/**
@@ -274,7 +285,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
@@ -301,8 +312,72 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
+ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd.rdd, beanClass)
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, a runtime exception will be thrown.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ * val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+ *
+ * val schema =
+ * StructType(
+ * StructField("name", StringType, false) ::
+ * StructField("age", IntegerType, true) :: Nil)
+ *
+ * val people =
+ * sc.textFile("examples/src/main/resources/people.txt").map(
+ * _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
+ * val dataFrame = sqlContext.applySchema(people, schema)
+ * dataFrame.printSchema
+ * // root
+ * // |-- name: string (nullable = false)
+ * // |-- age: integer (nullable = true)
+ *
+ * dataFrame.registerTempTable("people")
+ * sqlContext.sql("select name from people").collect.foreach(println)
+ * }}}
+ *
+ * @group userf
+ */
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd, beanClass)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
- applySchema(rdd.rdd, beanClass)
+ createDataFrame(rdd, beanClass)
}
/**
@@ -375,7 +450,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
@@ -393,7 +468,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index fa4cdecbcb340..1d71039872434 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -180,7 +180,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("!==") {
- val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -240,7 +240,7 @@ class ColumnExpressionSuite extends QueryTest {
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
}
- val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 55fd0b0892fa1..bba8899651259 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -34,6 +34,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
TestData
import org.apache.spark.sql.test.TestSQLContext.implicits._
+ val sqlCtx = TestSQLContext
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
@@ -669,7 +670,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -699,7 +700,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = applySchema(rowRDD2, schema2)
+ val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -724,7 +725,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD3, schema2)
+ val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -769,7 +770,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df108a9d262bb..c3210733f1d42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -71,7 +71,7 @@ class PlannerSuite extends FunSuite {
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
- applySchema(rowRDD, schema).registerTempTable("testLimit")
+ createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e581ac9b50c2b..21e70936102fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -54,7 +54,7 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count)
@@ -62,8 +62,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE with overwrite") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.DROPTEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
@@ -75,8 +75,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.APPENDTEST", false)
srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
@@ -85,8 +85,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to truncate") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false)
srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
@@ -95,8 +95,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("Incompatible INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
intercept[org.apache.spark.SparkException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 4fc92e3e3b8c0..fde4b47438c56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -820,7 +820,7 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDataFrame
val result = df2.toJSON.collect()
@@ -841,7 +841,7 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD2, schema2)
+ val df3 = createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDataFrame
val result2 = df4.toJSON.collect()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 43da7519ac8db..89b18c3439cf6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -97,7 +97,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil)
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m MAP )")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -142,7 +142,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
val schema = StructType(Seq(
StructField("a", ArrayType(StringType, containsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithArrayValue")
sql("CREATE TABLE hiveTableWithArrayValue(a Array )")
sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
@@ -159,7 +159,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m Map )")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -176,7 +176,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Row(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithStructValue")
sql("CREATE TABLE hiveTableWithStructValue(s Struct )")
sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 49fe79d989259..9a6e8650a0ec4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.hive.HiveShim
+import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -34,6 +35,7 @@ case class Nested3(f3: Int)
class SQLQuerySuite extends QueryTest {
import org.apache.spark.sql.hive.test.TestHive.implicits._
+ val sqlCtx = TestHive
test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
@@ -277,7 +279,7 @@ class SQLQuerySuite extends QueryTest {
val rowRdd = sparkContext.parallelize(row :: Nil)
- applySchema(rowRdd, schema).registerTempTable("testTable")
+ sqlCtx.createDataFrame(rowRdd, schema).registerTempTable("testTable")
sql(
"""CREATE TABLE nullValuesInInnerComplexTypes
From 45df77b8418873a00d770e435358bf603765595f Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Tue, 10 Feb 2015 19:40:51 -0800
Subject: [PATCH 057/817] [SPARK-5709] [SQL] Add EXPLAIN support in DataFrame
API for debugging purpose
Author: Cheng Hao
Closes #4496 from chenghao-intel/df_explain and squashes the following commits:
552aa58 [Cheng Hao] Add explain support for DF
---
.../main/scala/org/apache/spark/sql/Column.scala | 8 ++++++++
.../main/scala/org/apache/spark/sql/DataFrame.scala | 6 ++++++
.../scala/org/apache/spark/sql/DataFrameImpl.scala | 13 ++++++++++---
.../org/apache/spark/sql/execution/commands.scala | 7 +++++--
.../scala/org/apache/spark/sql/hive/HiveQl.scala | 8 +++-----
5 files changed, 32 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 1011bf0bb5ef4..b0e95908ee71a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -600,6 +600,14 @@ trait Column extends DataFrame {
def desc: Column = exprToColumn(SortOrder(expr, Descending), computable = false)
def asc: Column = exprToColumn(SortOrder(expr, Ascending), computable = false)
+
+ override def explain(extended: Boolean): Unit = {
+ if (extended) {
+ println(expr)
+ } else {
+ println(expr.prettyString)
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ca8d552c5febf..17900c5ee3892 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -124,6 +124,12 @@ trait DataFrame extends RDDApi[Row] {
/** Prints the schema to the console in a nice tree format. */
def printSchema(): Unit
+ /** Prints the plans (logical and physical) to the console for debugging purposes. */
+ def explain(extended: Boolean): Unit
+
+ /** Only prints the physical plan to the console for debugging purposes. */
+ def explain(): Unit = explain(false)
+
/**
* Returns true if the `collect` and `take` methods can be run locally
* (without any Spark executors).
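A minimal usage sketch of the new API, assuming a local SparkContext and the `toDataFrame(colNames)` helper used elsewhere in this series (app name, master, and column names are placeholders): `explain()` prints only the physical plan, while `explain(true)` also prints the logical plans.
```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ExplainSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("explain-sketch").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq((1, "a"), (2, "b"))).toDataFrame("id", "name")
    val filtered = df.filter(df("id") > 1)

    filtered.explain()      // physical plan only
    filtered.explain(true)  // logical and physical plans

    sc.stop()
  }
}
```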
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 0134b038f3c5a..9638ce0865db0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -30,12 +30,11 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, ResolvedStar, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
+import org.apache.spark.sql.execution.{ExplainCommand, LogicalRDD, EvaluatePython}
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{NumericType, StructType}
@@ -115,6 +114,14 @@ private[sql] class DataFrameImpl protected[sql](
override def printSchema(): Unit = println(schema.treeString)
+ override def explain(extended: Boolean): Unit = {
+ ExplainCommand(
+ logicalPlan,
+ extended = extended).queryExecution.executedPlan.executeCollect().map {
+ r => println(r.getString(0))
+ }
+ }
+
override def isLocal: Boolean = {
logicalPlan.isInstanceOf[LocalRelation]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 335757087deef..2b1726ad4e89f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import scala.collection.mutable.ArrayBuffer
@@ -116,7 +117,9 @@ case class SetCommand(
@DeveloperApi
case class ExplainCommand(
logicalPlan: LogicalPlan,
- override val output: Seq[Attribute], extended: Boolean = false) extends RunnableCommand {
+ override val output: Seq[Attribute] =
+ Seq(AttributeReference("plan", StringType, nullable = false)()),
+ extended: Boolean = false) extends RunnableCommand {
// Run through the optimizer to generate the physical plan.
override def run(sqlContext: SQLContext) = try {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 8618301ba84d6..f3c9e63652a8e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -466,23 +466,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// Just fake explain for any of the native commands.
case Token("TOK_EXPLAIN", explainArgs)
if noExplainCommands.contains(explainArgs.head.getText) =>
- ExplainCommand(NoRelation, Seq(AttributeReference("plan", StringType, nullable = false)()))
+ ExplainCommand(NoRelation)
case Token("TOK_EXPLAIN", explainArgs)
if "TOK_CREATETABLE" == explainArgs.head.getText =>
val Some(crtTbl) :: _ :: extended :: Nil =
getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs)
ExplainCommand(
nodeToPlan(crtTbl),
- Seq(AttributeReference("plan", StringType,nullable = false)()),
- extended != None)
+ extended = extended.isDefined)
case Token("TOK_EXPLAIN", explainArgs) =>
// Ignore FORMATTED if present.
val Some(query) :: _ :: extended :: Nil =
getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
ExplainCommand(
nodeToPlan(query),
- Seq(AttributeReference("plan", StringType, nullable = false)()),
- extended != None)
+ extended = extended.isDefined)
case Token("TOK_DESCTABLE", describeArgs) =>
// Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
From 7e24249af1e2f896328ef0402fa47db78cb6f9ec Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 10 Feb 2015 19:50:44 -0800
Subject: [PATCH 058/817] [SQL][DataFrame] Fix column computability bug.
Do not recursively strip out projects. Only strip the first level project.
```scala
df("colA") + df("colB").as("colC")
```
Previously, the above would construct an invalid plan.
Author: Reynold Xin
Closes #4519 from rxin/computability and squashes the following commits:
87ff763 [Reynold Xin] Code review feedback.
015c4fc [Reynold Xin] [SQL][DataFrame] Fix column computability.
---
.../MatrixFactorizationModel.scala | 2 +-
.../scala/org/apache/spark/sql/Column.scala | 35 ++++++++++++++-----
.../org/apache/spark/sql/SQLContext.scala | 4 +--
.../spark/sql/ColumnExpressionSuite.scala | 13 +++++--
4 files changed, 39 insertions(+), 15 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 9ff06ac362a31..16979c9ed43ca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -180,7 +180,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
val sqlContext = new SQLContext(sc)
- import sqlContext.implicits.createDataFrame
+ import sqlContext.implicits._
val metadata = (thisClassName, thisFormatVersion, model.rank)
val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b0e95908ee71a..9d5d6e78bd487 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -66,27 +66,44 @@ trait Column extends DataFrame {
*/
def isComputable: Boolean
+ /** Removes the top project so we can get to the underlying plan. */
+ private def stripProject(p: LogicalPlan): LogicalPlan = p match {
+ case Project(_, child) => child
+ case p => sys.error("Unexpected logical plan (expected Project): " + p)
+ }
+
private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
- val plan = Project(Seq(expr match {
+ val namedExpr = expr match {
case named: NamedExpression => named
case unnamed: Expression => Alias(unnamed, "col")()
- }), baseCol.plan)
+ }
+ val plan = Project(Seq(namedExpr), stripProject(baseCol.plan))
Column(baseCol.sqlContext, plan, expr)
}
+ /**
+ * Construct a new column based on the expression and the other column value.
+ *
+ * There are two cases that can happen here:
+ * If otherValue is a constant, it is first turned into a Column.
+ * If otherValue is a Column, then:
+ * - If this column and otherValue are both computable and come from the same logical plan,
+ * then we can construct a ComputableColumn by applying a Project on top of the base plan.
+ * - If this column is not computable, but otherValue is computable, then we can construct
+ * a ComputableColumn based on otherValue's base plan.
+ * - If this column is computable, but otherValue is not, then we can construct a
+ * ComputableColumn based on this column's base plan.
+ * - If neither column is computable, then we create an IncomputableColumn.
+ */
private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
- // Removes all the top level projection and subquery so we can get to the underlying plan.
- @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
- case Project(_, child) => stripProject(child)
- case Subquery(_, child) => stripProject(child)
- case _ => p
- }
-
+ // lit(otherValue) returns a Column always.
(this, lit(otherValue)) match {
case (left: ComputableColumn, right: ComputableColumn) =>
if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
computableCol(right, newExpr(right))
} else {
+ // We don't want to throw an exception here because "df1("a") === df2("b")" can be
+ // a valid expression for join conditions, even though standalone they are not valid.
Column(newExpr(right))
}
case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
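To illustrate the cases enumerated in the constructColumn comment above, a small sketch mirroring the new ColumnExpressionSuite test (assuming a local SparkContext; app name, master, and column names are arbitrary): two columns drawn from the same DataFrame yield a computable column, built by placing a single Project over the shared base plan.
```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ComputableColumnSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("computable-sketch").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq((1, 2, 3))).toDataFrame("a", "b", "c")

    // Both operands share df's logical plan, so the result is a ComputableColumn
    // and can be collected directly (expected: a single Row(3), as in the new test).
    val sum = df("a") + df("b").as("sum")
    sum.collect().foreach(println)

    sc.stop()
  }
}
```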
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 523911d108029..05ac1623d78ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -183,14 +183,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+ implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
self.createDataFrame(rdd)
}
/**
* Creates a DataFrame from a local Seq of Product.
*/
- implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+ implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
self.createDataFrame(data)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 1d71039872434..e3e6f652ed3ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
@@ -44,10 +45,10 @@ class ColumnExpressionSuite extends QueryTest {
shouldBeComputable(-testData2("a"))
shouldBeComputable(!testData2("a"))
- shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
- shouldBeComputable(
+ shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
+ shouldNotBeComputable(
testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
- shouldBeComputable(
+ shouldNotBeComputable(
testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))
// Literals and unresolved columns should not be computable.
@@ -66,6 +67,12 @@ class ColumnExpressionSuite extends QueryTest {
shouldNotBeComputable(sum(testData2("a")))
}
+ test("collect on column produced by a binary operator") {
+ val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c")
+ checkAnswer(df("a") + df("b"), Seq(Row(3)))
+ checkAnswer(df("a") + df("b").as("c"), Seq(Row(3)))
+ }
+
test("star") {
checkAnswer(testData.select($"*"), testData.collect().toSeq)
}
From 1cb37700753437045b15c457b983532cd5a27fa5 Mon Sep 17 00:00:00 2001
From: mcheah
Date: Tue, 10 Feb 2015 20:12:18 -0800
Subject: [PATCH 059/817] [SPARK-4879] Use driver to coordinate Hadoop output
committing for speculative tasks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Previously, SparkHadoopWriter always committed its tasks without question. The problem is that when speculation is enabled, this can sometimes result in multiple tasks committing their output to the same file. Even though an HDFS-writing task may be re-launched due to speculation, the original task is not killed and may eventually commit as well.
This can cause strange race conditions where multiple tasks that commit interfere with each other, with the result being that some partition files are actually lost entirely. For more context on these kinds of scenarios, see SPARK-4879.
In Hadoop MapReduce jobs, the application master is a central coordinator that authorizes whether or not any given task can commit. Before a task commits its output, it queries the application master as to whether or not such a commit is safe, and the application master does bookkeeping as tasks are requesting commits. Duplicate tasks that would write to files that were already written to from other tasks are prohibited from committing.
This patch emulates that functionality - the crucial missing component was a central arbitrator, which is now a module called the OutputCommitCoordinator. The coordinator lives on the driver and the executors can obtain a reference to this actor and request its permission to commit. As tasks commit and are reported as completed successfully or unsuccessfully by the DAGScheduler, the commit coordinator is informed of the task completion events as well to update its internal state.
Future work includes more rigorous unit testing and extra optimizations should this patch cause a performance regression. It is unclear what the overall cost of communicating back to the driver on every Hadoop-committing task will be. It is also important for those hitting this issue to backport this onto previous versions of Spark, because the bug has serious consequences: data is lost.
Currently, the OutputCommitCoordinator is only used when `spark.speculation` is true. It can be disabled by setting `spark.hadoop.outputCommitCoordination.enabled=false` in SparkConf.
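A minimal configuration sketch of what the paragraph above describes (app name and master are placeholders): enabling speculation turns the coordination on by default, and the undocumented escape hatch can switch it back off.
```scala
import org.apache.spark.{SparkConf, SparkContext}

object CommitCoordinationConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("commit-coordination-sketch")
      .setMaster("local")
      // With speculation on, SparkHadoopWriter asks the driver before committing task output.
      .set("spark.speculation", "true")
      // Escape hatch: uncomment to bypass the coordination even when speculation is enabled.
      // .set("spark.hadoop.outputCommitCoordination.enabled", "false")

    val sc = new SparkContext(conf)
    // ... run Hadoop-writing jobs as usual ...
    sc.stop()
  }
}
```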
This patch is an updated version of #4155 (by mccheah), which in turn was an updated version of this PR.
Closes #4155.
Author: mcheah
Author: Josh Rosen
Closes #4066 from JoshRosen/SPARK-4879-sparkhadoopwriter-fix and squashes the following commits:
658116b [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4879-sparkhadoopwriter-fix
ed783b2 [Josh Rosen] Address Andrew’s feedback.
e7be65a [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4879-sparkhadoopwriter-fix
14861ea [Josh Rosen] splitID -> partitionID in a few places
ed8b554 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4879-sparkhadoopwriter-fix
48d5c1c [Josh Rosen] Roll back copiesRunning change in TaskSetManager
3969f5f [Josh Rosen] Re-enable guarding of commit coordination with spark.speculation setting.
ede7590 [Josh Rosen] Add test to ensure that a job that denies all commits cannot complete successfully.
97da5fe [Josh Rosen] Use actor only for RPC; call methods directly in DAGScheduler.
f582574 [Josh Rosen] Some cleanup in OutputCommitCoordinatorSuite
a7c0e29 [Josh Rosen] Create fake TaskInfo using dummy fields instead of Mockito.
997b41b [Josh Rosen] Roll back unnecessary DAGSchedulerSingleThreadedProcessLoop refactoring:
459310a [Josh Rosen] Roll back TaskSetManager changes that broke other tests.
dd00b7c [Josh Rosen] Move CommitDeniedException to executors package; remove `@DeveloperAPI` annotation.
c79df98 [Josh Rosen] Some misc. code style + doc changes:
f7d69c5 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4879-sparkhadoopwriter-fix
92e6dc9 [Josh Rosen] Bug fix: use task ID instead of StageID to index into authorizedCommitters.
b344bad [Josh Rosen] (Temporarily) re-enable “always coordinate” for testing purposes.
0aec91e [Josh Rosen] Only coordinate when speculation is enabled; add configuration option to bypass new coordination.
594e41a [mcheah] Fixing a scalastyle error
60a47f4 [mcheah] Writing proper unit test for OutputCommitCoordinator and fixing bugs.
d63f63f [mcheah] Fixing compiler error
9fe6495 [mcheah] Fixing scalastyle
1df2a91 [mcheah] Throwing exception if SparkHadoopWriter commit denied
d431144 [mcheah] Using more concurrency to process OutputCommitCoordinator requests.
c334255 [mcheah] Properly handling messages that could be sent after actor shutdown.
8d5a091 [mcheah] Was mistakenly serializing the accumulator in test suite.
9c6a4fa [mcheah] More OutputCommitCoordinator cleanup on stop()
78eb1b5 [mcheah] Better OutputCommitCoordinatorActor stopping; simpler canCommit
83de900 [mcheah] Making the OutputCommitCoordinatorMessage serializable
abc7db4 [mcheah] TaskInfo can't be null in DAGSchedulerSuite
f135a8e [mcheah] Moving the output commit coordinator from class into method.
1c2b219 [mcheah] Renaming oudated names for test function classes
66a71cd [mcheah] Removing whitespace modifications
6b543ba [mcheah] Removing redundant accumulator in unit test
c9decc6 [mcheah] Scalastyle fixes
bc80770 [mcheah] Unit tests for OutputCommitCoordinator
6e6f748 [mcheah] [SPARK-4879] Use the Spark driver to authorize Hadoop commits.
---
.../scala/org/apache/spark/SparkContext.scala | 11 +-
.../scala/org/apache/spark/SparkEnv.scala | 22 +-
.../org/apache/spark/SparkHadoopWriter.scala | 43 +++-
.../org/apache/spark/TaskEndReason.scala | 14 ++
.../executor/CommitDeniedException.scala | 35 +++
.../org/apache/spark/executor/Executor.scala | 5 +
.../apache/spark/scheduler/DAGScheduler.scala | 15 +-
.../scheduler/OutputCommitCoordinator.scala | 172 ++++++++++++++
.../spark/scheduler/TaskSchedulerImpl.scala | 9 +-
.../spark/scheduler/TaskSetManager.scala | 8 +-
.../spark/scheduler/DAGSchedulerSuite.scala | 25 +-
.../OutputCommitCoordinatorSuite.scala | 213 ++++++++++++++++++
12 files changed, 549 insertions(+), 23 deletions(-)
create mode 100644 core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
create mode 100644 core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
create mode 100644 core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 53fce6b0defdf..24a316e40e673 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -249,7 +249,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
// Create the Spark execution environment (cache, map output tracker, etc)
- private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+
+ // This function allows components created by SparkEnv to be mocked in unit tests:
+ private[spark] def createSparkEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+ }
+
+ private[spark] val env = createSparkEnv(conf, isLocal, listenerBus)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index b63bea5b102b6..2a0c7e756dd3a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,7 +34,8 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
-import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
+import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
@@ -67,6 +68,7 @@ class SparkEnv (
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
+ val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
private[spark] var isStopped = false
@@ -88,6 +90,7 @@ class SparkEnv (
blockManager.stop()
blockManager.master.stop()
metricsSystem.stop()
+ outputCommitCoordinator.stop()
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
@@ -169,7 +172,8 @@ object SparkEnv extends Logging {
private[spark] def createDriverEnv(
conf: SparkConf,
isLocal: Boolean,
- listenerBus: LiveListenerBus): SparkEnv = {
+ listenerBus: LiveListenerBus,
+ mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
val hostname = conf.get("spark.driver.host")
@@ -181,7 +185,8 @@ object SparkEnv extends Logging {
port,
isDriver = true,
isLocal = isLocal,
- listenerBus = listenerBus
+ listenerBus = listenerBus,
+ mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
}
@@ -220,7 +225,8 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean,
listenerBus: LiveListenerBus = null,
- numUsableCores: Int = 0): SparkEnv = {
+ numUsableCores: Int = 0,
+ mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
// Listener bus is only used on the driver
if (isDriver) {
@@ -368,6 +374,13 @@ object SparkEnv extends Logging {
"levels using the RDD.persist() method instead.")
}
+ val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
+ new OutputCommitCoordinator(conf)
+ }
+ val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
+ new OutputCommitCoordinatorActor(outputCommitCoordinator))
+ outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
+
val envInstance = new SparkEnv(
executorId,
actorSystem,
@@ -384,6 +397,7 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
+ outputCommitCoordinator,
conf)
// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 40237596570de..6eb4537d10477 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
+import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
@@ -105,24 +106,56 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
def commit() {
val taCtxt = getTaskContext()
val cmtr = getOutputCommitter()
- if (cmtr.needsTaskCommit(taCtxt)) {
+
+ // Called after we have decided to commit
+ def performCommit(): Unit = {
try {
cmtr.commitTask(taCtxt)
- logInfo (taID + ": Committed")
+ logInfo (s"$taID: Committed")
} catch {
- case e: IOException => {
+ case e: IOException =>
logError("Error committing the output of task: " + taID.value, e)
cmtr.abortTask(taCtxt)
throw e
+ }
+ }
+
+ // First, check whether the task's output has already been committed by some other attempt
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ // The task output needs to be committed, but we don't know whether some other task attempt
+ // might be racing to commit the same output partition. Therefore, coordinate with the driver
+ // in order to determine whether this attempt can commit (see SPARK-4879).
+ val shouldCoordinateWithDriver: Boolean = {
+ val sparkConf = SparkEnv.get.conf
+ // We only need to coordinate with the driver if there are multiple concurrent task
+ // attempts, which should only occur if speculation is enabled
+ val speculationEnabled = sparkConf.getBoolean("spark.speculation", false)
+ // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
+ sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
+ }
+ if (shouldCoordinateWithDriver) {
+ val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
+ val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID)
+ if (canCommit) {
+ performCommit()
+ } else {
+ val msg = s"$taID: Not committed because the driver did not authorize commit"
+ logInfo(msg)
+ // We need to abort the task so that the driver can reschedule new attempts, if necessary
+ cmtr.abortTask(taCtxt)
+ throw new CommitDeniedException(msg, jobID, splitID, attemptID)
}
+ } else {
+ // Speculation is disabled or a user has chosen to manually bypass the commit coordination
+ performCommit()
}
} else {
- logInfo ("No need to commit output of task: " + taID.value)
+ // Some other attempt committed the output, so we do nothing and signal success
+ logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}")
}
}
def commitJob() {
- // always ? Or if cmtr.needsTaskCommit ?
val cmtr = getOutputCommitter()
cmtr.commitJob(getJobContext())
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index af5fd8e0ac00c..29a5cd5fdac76 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -146,6 +146,20 @@ case object TaskKilled extends TaskFailedReason {
override def toErrorString: String = "TaskKilled (killed intentionally)"
}
+/**
+ * :: DeveloperApi ::
+ * Task requested the driver to commit, but was denied.
+ */
+@DeveloperApi
+case class TaskCommitDenied(
+ jobID: Int,
+ partitionID: Int,
+ attemptID: Int)
+ extends TaskFailedReason {
+ override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
+ s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
+}
+
/**
* :: DeveloperApi ::
* The task failed because the executor that it was running on was lost. This may happen because
diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
new file mode 100644
index 0000000000000..f7604a321f007
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{TaskCommitDenied, TaskEndReason}
+
+/**
+ * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
+ */
+class CommitDeniedException(
+ msg: String,
+ jobID: Int,
+ splitID: Int,
+ attemptID: Int)
+ extends Exception(msg) {
+
+ def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)
+
+}
+
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 6b22dcd6f5cbf..b684fb704956b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -253,6 +253,11 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
+ case cDE: CommitDeniedException => {
+ val reason = cDE.toTaskEndReason
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ }
+
case t: Throwable => {
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 1cfe98673773a..79035571adb05 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
+import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
@@ -63,7 +63,7 @@ class DAGScheduler(
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv,
- clock: Clock = SystemClock)
+ clock: org.apache.spark.util.Clock = SystemClock)
extends Logging {
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
@@ -126,6 +126,8 @@ class DAGScheduler(
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
+ private val outputCommitCoordinator = env.outputCommitCoordinator
+
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessLoop.post(BeginEvent(task, taskInfo))
@@ -808,6 +810,7 @@ class DAGScheduler(
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
+ outputCommitCoordinator.stageStart(stage.id)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
@@ -865,6 +868,7 @@ class DAGScheduler(
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
+ outputCommitCoordinator.stageEnd(stage.id)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -909,6 +913,9 @@ class DAGScheduler(
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)
+ outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
+ event.taskInfo.attempt, event.reason)
+
// The success case is dealt with separately below, since we need to compute accumulator
// updates before posting.
if (event.reason != Success) {
@@ -921,6 +928,7 @@ class DAGScheduler(
// Skip all the actions if the stage has been cancelled.
return
}
+
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = {
@@ -1073,6 +1081,9 @@ class DAGScheduler(
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
+ case commitDenied: TaskCommitDenied =>
+ // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
+
case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
new file mode 100644
index 0000000000000..759df023a6dcf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark._
+import org.apache.spark.util.{AkkaUtils, ActorLogReceive}
+
+private sealed trait OutputCommitCoordinationMessage extends Serializable
+
+private case object StopCoordinator extends OutputCommitCoordinationMessage
+private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long)
+
+/**
+ * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
+ * policy.
+ *
+ * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is
+ * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit
+ * output will be forwarded to the driver's OutputCommitCoordinator.
+ *
+ * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests)
+ * for an extensive design discussion.
+ */
+private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
+
+ // Initialized by SparkEnv
+ var coordinatorActor: Option[ActorRef] = None
+ private val timeout = AkkaUtils.askTimeout(conf)
+ private val maxAttempts = AkkaUtils.numRetries(conf)
+ private val retryInterval = AkkaUtils.retryWaitMs(conf)
+
+ private type StageId = Int
+ private type PartitionId = Long
+ private type TaskAttemptId = Long
+
+ /**
+ * Map from an active stage's id => partition id => task attempt with exclusive lock on committing
+ * output for that partition.
+ *
+ * Entries are added to the top-level map when stages start and are removed when they finish
+ * (either successfully or unsuccessfully).
+ *
+ * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
+ */
+ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
+ private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]]
+
+ /**
+ * Called by tasks to ask whether they can commit their output to HDFS.
+ *
+ * If a task attempt has been authorized to commit, then all other attempts to commit the same
+ * task will be denied. If the authorized task attempt fails (e.g. due to its executor being
+ * lost), then a subsequent task attempt may be authorized to commit its output.
+ *
+ * @param stage the stage number
+ * @param partition the partition number
+ * @param attempt a unique identifier for this task attempt
+ * @return true if this task is authorized to commit, false otherwise
+ */
+ def canCommit(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId): Boolean = {
+ val msg = AskPermissionToCommitOutput(stage, partition, attempt)
+ coordinatorActor match {
+ case Some(actor) =>
+ AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout)
+ case None =>
+ logError(
+ "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?")
+ false
+ }
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def stageStart(stage: StageId): Unit = synchronized {
+ authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]()
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
+ authorizedCommittersByStage.remove(stage)
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def taskCompleted(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId,
+ reason: TaskEndReason): Unit = synchronized {
+ val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, {
+ logDebug(s"Ignoring task completion for completed stage")
+ return
+ })
+ reason match {
+ case Success =>
+ // The task output has been committed successfully
+ case denied: TaskCommitDenied =>
+ logInfo(
+ s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt")
+ case otherReason =>
+ logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" +
+ s" clearing lock")
+ authorizedCommitters.remove(partition)
+ }
+ }
+
+ def stop(): Unit = synchronized {
+ coordinatorActor.foreach(_ ! StopCoordinator)
+ coordinatorActor = None
+ authorizedCommittersByStage.clear()
+ }
+
+ // Marked private[scheduler] instead of private so this can be mocked in tests
+ private[scheduler] def handleAskPermissionToCommit(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId): Boolean = synchronized {
+ authorizedCommittersByStage.get(stage) match {
+ case Some(authorizedCommitters) =>
+ authorizedCommitters.get(partition) match {
+ case Some(existingCommitter) =>
+ logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " +
+ s"existingCommitter = $existingCommitter")
+ false
+ case None =>
+ logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition")
+ authorizedCommitters(partition) = attempt
+ true
+ }
+ case None =>
+ logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit")
+ false
+ }
+ }
+}
+
+private[spark] object OutputCommitCoordinator {
+
+ // This actor is used only for RPC
+ class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator)
+ extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case AskPermissionToCommitOutput(stage, partition, taskAttempt) =>
+ sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)
+ case StopCoordinator =>
+ logInfo("OutputCommitCoordinator stopped!")
+ context.stop(self)
+ sender ! true
+ }
+ }
+}
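A driver-side sketch of the "first committer wins" bookkeeping shown above. It is assumed to live under org.apache.spark.scheduler so the private[scheduler] hooks are visible, and no actor is wired in, so handleAskPermissionToCommit is invoked directly, as the DAGScheduler-facing methods would do.
```scala
package org.apache.spark.scheduler

import org.apache.spark.SparkConf

object FirstCommitterWinsSketch {
  def main(args: Array[String]): Unit = {
    val coordinator = new OutputCommitCoordinator(new SparkConf())
    val stage = 0
    coordinator.stageStart(stage)

    // The first attempt for partition 0 is authorized to commit.
    println(coordinator.handleAskPermissionToCommit(stage, partition = 0L, attempt = 1L)) // true
    // A racing second attempt for the same partition is denied.
    println(coordinator.handleAskPermissionToCommit(stage, partition = 0L, attempt = 2L)) // false

    coordinator.stageEnd(stage)
    coordinator.stop()
  }
}
```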
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 79f84e70df9d5..54f8fcfc416d1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl(
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
+ val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
@@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl(
backend.reviveOffers()
}
+ // Label as private[scheduler] to allow tests to swap in different task set managers if necessary
+ private[scheduler] def createTaskSetManager(
+ taskSet: TaskSet,
+ maxTaskFailures: Int): TaskSetManager = {
+ new TaskSetManager(this, taskSet, maxTaskFailures)
+ }
+
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 55024ecd55e61..99a5f7117790d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -292,7 +292,8 @@ private[spark] class TaskSetManager(
* an attempt running on this host, in case the host is slow. In addition, the task should meet
* the given locality constraint.
*/
- private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ // Labeled as protected to allow tests to override how speculative tasks are provided, if necessary
+ protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
@@ -708,7 +709,10 @@ private[spark] class TaskSetManager(
put(info.executorId, clock.getTime())
sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
addPendingTask(index)
- if (!isZombie && state != TaskState.KILLED) {
+ if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) {
+ // If a task failed because its attempt to commit was denied, do not count this failure
+ // towards failing the stage. This is intended to prevent spurious stage failures in cases
+ // where many speculative tasks are launched and denied to commit.
assert (null != failureReason)
numFailures(index) += 1
if (numFailures(index) >= maxTaskFailures) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index eb116213f69fc..9d0c1273695f6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -208,7 +208,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
- runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null))
}
}
}
@@ -219,7 +219,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2,
- Map[Long, Any]((accumId, 1)), null, null))
+ Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null))
}
}
}
@@ -476,7 +476,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
- null,
+ createFakeTaskInfo(),
null))
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
assert(sparkListener.failedStages.contains(1))
@@ -487,7 +487,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
- null,
+ createFakeTaskInfo(),
null))
// The SparkListener should not receive redundant failure events.
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
@@ -507,14 +507,14 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(newEpoch > oldEpoch)
val taskSet = taskSets(0)
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
// should work because it's a non-failed host
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null))
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
// should work because it's a new epoch
taskSet.tasks(1).epoch = newEpoch
- runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null))
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
@@ -766,5 +766,14 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(scheduler.shuffleToMapStage.isEmpty)
assert(scheduler.waitingStages.isEmpty)
}
+
+ // Nothing in this test should break if the task info's fields are null, but
+ // OutputCommitCoordinator requires the task info itself to not be null.
+ private def createFakeTaskInfo(): TaskInfo = {
+ val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false)
+ info.finishTime = 1 // to prevent spurious errors in JobProgressListener
+ info
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
new file mode 100644
index 0000000000000..3cc860caa1d9b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.File
+import java.util.concurrent.TimeoutException
+
+import org.mockito.Matchers
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter}
+
+import org.apache.spark._
+import org.apache.spark.rdd.{RDD, FakeOutputCommitter}
+import org.apache.spark.util.Utils
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+/**
+ * Unit tests for the output commit coordination functionality.
+ *
+ * The unit test makes both the original task and the speculated task
+ * attempt to commit, where committing is emulated by creating a
+ * directory. If both tasks create directories then the end result is
+ * a failure.
+ *
+ * Note that there are some aspects of this test that are less than ideal.
+ * In particular, the test mocks the speculation-dequeuing logic to always
+ * dequeue a task and consider it as speculated. Immediately after initially
+ * submitting the tasks and calling reviveOffers(), reviveOffers() is invoked
+ * again to pick up the speculated task. This may alter the original behavior
+ * in an overly unrealistic fashion.
+ *
+ * Also, the validation is done by checking the number of files in a directory.
+ * Ideally, an accumulator would be used for this, where we could increment
+ * the accumulator in the output committer's commitTask() call. If commitTask()
+ * were erroneously called twice, the test would ideally fail because the
+ * accumulator would be incremented twice.
+ *
+ * The problem with this test implementation is that when both a speculated task and
+ * its original counterpart complete, only one of the accumulator's increments is
+ * captured. This creates a blind spot: even if the OutputCommitCoordinator logic
+ * were not in SparkHadoopWriter, the tests would still pass, because only one of
+ * the increments would be captured even though the commit in both tasks was
+ * executed erroneously.
+ */
+class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
+
+ var outputCommitCoordinator: OutputCommitCoordinator = null
+ var tempDir: File = null
+ var sc: SparkContext = null
+
+ before {
+ tempDir = Utils.createTempDir()
+ val conf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName)
+ .set("spark.speculation", "true")
+ sc = new SparkContext(conf) {
+ override private[spark] def createSparkEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ outputCommitCoordinator = spy(new OutputCommitCoordinator(conf))
+ // Use Mockito.spy() to maintain the default infrastructure everywhere else.
+ // This mocking allows us to control the coordinator responses in test cases.
+ SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator))
+ }
+ }
+ // Use Mockito.spy() to maintain the default infrastructure everywhere else
+ val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])
+
+ doAnswer(new Answer[Unit]() {
+ override def answer(invoke: InvocationOnMock): Unit = {
+ // Submit the tasks, then force the task scheduler to dequeue the
+ // speculated task
+ invoke.callRealMethod()
+ mockTaskScheduler.backend.reviveOffers()
+ }
+ }).when(mockTaskScheduler).submitTasks(Matchers.any())
+
+ doAnswer(new Answer[TaskSetManager]() {
+ override def answer(invoke: InvocationOnMock): TaskSetManager = {
+ val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
+ new TaskSetManager(mockTaskScheduler, taskSet, 4) {
+ var hasDequeuedSpeculatedTask = false
+ override def dequeueSpeculativeTask(
+ execId: String,
+ host: String,
+ locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = {
+ if (!hasDequeuedSpeculatedTask) {
+ hasDequeuedSpeculatedTask = true
+ Some((0, TaskLocality.PROCESS_LOCAL))
+ } else {
+ None
+ }
+ }
+ }
+ }
+ }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), Matchers.any())
+
+ sc.taskScheduler = mockTaskScheduler
+ val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler)
+ sc.taskScheduler.setDAGScheduler(dagSchedulerWithMockTaskScheduler)
+ sc.dagScheduler = dagSchedulerWithMockTaskScheduler
+ }
+
+ after {
+ sc.stop()
+ tempDir.delete()
+ outputCommitCoordinator = null
+ }
+
+ test("Only one of two duplicate commit tasks should commit") {
+ val rdd = sc.parallelize(Seq(1), 1)
+ sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _,
+ 0 until rdd.partitions.size, allowLocal = false)
+ assert(tempDir.list().size === 1)
+ }
+
+ test("If commit fails, if task is retried it should not be locked, and will succeed.") {
+ val rdd = sc.parallelize(Seq(1), 1)
+ sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _,
+ 0 until rdd.partitions.size, allowLocal = false)
+ assert(tempDir.list().size === 1)
+ }
+
+ test("Job should not complete if all commits are denied") {
+ // Create a mock OutputCommitCoordinator that denies all attempts to commit
+ doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
+ Matchers.any(), Matchers.any(), Matchers.any())
+ val rdd: RDD[Int] = sc.parallelize(Seq(1), 1)
+ def resultHandler(x: Int, y: Unit): Unit = {}
+ val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd,
+ OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully,
+ 0 until rdd.partitions.size, resultHandler, 0)
+ // It's an error if the job completes successfully even though no committer was
+ // authorized; since every commit is denied, the job should never finish, so the
+ // Await below is expected to time out.
+ intercept[TimeoutException] {
+ Await.result(futureAction, 5 seconds)
+ }
+ assert(tempDir.list().size === 0)
+ }
+}
+
+/**
+ * Class with methods that can be passed to runJob to test commits with a mock committer.
+ */
+private case class OutputCommitFunctions(tempDirPath: String) {
+
+ // Mock output committer that simulates a successful commit (after commit is authorized)
+ private def successfulOutputCommitter = new FakeOutputCommitter {
+ override def commitTask(context: TaskAttemptContext): Unit = {
+ Utils.createDirectory(tempDirPath)
+ }
+ }
+
+ // Mock output committer that simulates a failed commit (after commit is authorized)
+ private def failingOutputCommitter = new FakeOutputCommitter {
+ override def commitTask(taskAttemptContext: TaskAttemptContext): Unit = {
+ throw new RuntimeException
+ }
+ }
+
+ def commitSuccessfully(iter: Iterator[Int]): Unit = {
+ val ctx = TaskContext.get()
+ runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter)
+ }
+
+ def failFirstCommitAttempt(iter: Iterator[Int]): Unit = {
+ val ctx = TaskContext.get()
+ runCommitWithProvidedCommitter(ctx, iter,
+ if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter)
+ }
+
+ private def runCommitWithProvidedCommitter(
+ ctx: TaskContext,
+ iter: Iterator[Int],
+ outputCommitter: OutputCommitter): Unit = {
+ def jobConf = new JobConf {
+ override def getOutputCommitter(): OutputCommitter = outputCommitter
+ }
+ val sparkHadoopWriter = new SparkHadoopWriter(jobConf) {
+ override def newTaskAttemptContext(
+ conf: JobConf,
+ attemptId: TaskAttemptID): TaskAttemptContext = {
+ mock(classOf[TaskAttemptContext])
+ }
+ }
+ sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber)
+ sparkHadoopWriter.commit()
+ }
+}
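For context, here is a minimal sketch of the real-world scenario this coordinator guards against: with speculation enabled, two attempts of the same partition can race to commit their output, and SparkHadoopWriter asks the OutputCommitCoordinator for permission before committing. This is a hypothetical standalone driver, not part of the suite, and the output path is made up.

import org.apache.spark.{SparkConf, SparkContext}

object SpeculativeSaveSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[4]")
      .setAppName("speculative-save-sketch")
      .set("spark.speculation", "true") // allow speculative task attempts
    val sc = new SparkContext(conf)
    // saveAsTextFile goes through SparkHadoopWriter; before a partition's output is
    // committed, the writer asks the OutputCommitCoordinator for permission, so only
    // one of two racing attempts is allowed to commit.
    sc.parallelize(1 to 100, 4).saveAsTextFile("/tmp/speculative-save-sketch")
    sc.stop()
  }
}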
From b969182659aa7ea94c38329b86d98a31b23efce8 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 10 Feb 2015 20:19:14 -0800
Subject: [PATCH 060/817] [SPARK-5729] Potential NPE in standalone REST API
If the user specifies a bad REST URL, the server will throw an NPE instead of propagating the error back. This is because the default `ErrorServlet` has the wrong prefix. This is a one-line fix. I will add more comprehensive tests in a separate patch.
Author: Andrew Or
Closes #4518 from andrewor14/rest-npe and squashes the following commits:
16b15bc [Andrew Or] Correct ErrorServlet context prefix
---
.../org/apache/spark/deploy/rest/StandaloneRestServer.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index 6e4486e20fcba..acd3a2b5abe6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -77,7 +77,7 @@ private[spark] class StandaloneRestServer(
new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> s"$baseContext/create/*",
new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*",
new StatusRequestServlet(masterActor, masterConf) -> s"$baseContext/status/*",
- new ErrorServlet -> "/" // default handler
+ new ErrorServlet -> "/*" // default handler
)
/** Start the server and return the bound port. */
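For context, a short sketch of why the path spec matters, assuming Jetty's ServletContextHandler (which the standalone REST server is built on); the servlet below is hypothetical. A mapping of "/*" is a catch-all prefix mapping under which getPathInfo() carries the request path, whereas "/" is the servlet-spec default mapping under which getPathInfo() is null, the kind of difference that can surface as an NPE in a servlet that parses the request path.

import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}

object PathSpecSketch {
  def main(args: Array[String]): Unit = {
    val handler = new ServletContextHandler()
    val fallbackServlet = new HttpServlet {
      override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
        // Mapped at "/*", getPathInfo() carries the otherwise-unmatched request path;
        // mapped at "/", it would be null for the same request.
        resp.getWriter.println(s"pathInfo = ${req.getPathInfo}")
      }
    }
    // The catch-all prefix mapping that the fix uses for the default handler.
    handler.addServlet(new ServletHolder(fallbackServlet), "/*")
  }
}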
From b8f88d32723eaea4807c10b5b79d0c76f30b0510 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 10 Feb 2015 20:40:21 -0800
Subject: [PATCH 061/817] [SPARK-5702][SQL] Allow short names for built-in data
sources.
Also took the chance to fix up some style ...
Author: Reynold Xin
Closes #4489 from rxin/SPARK-5702 and squashes the following commits:
74f42e3 [Reynold Xin] [SPARK-5702][SQL] Allow short names for built-in data sources.
---
.../apache/spark/sql/jdbc/JDBCRelation.scala | 26 +++----
.../apache/spark/sql/json/JSONRelation.scala | 1 +
.../org/apache/spark/sql/sources/ddl.scala | 77 ++++++++++---------
.../sql/sources/ResolvedDataSourceSuite.scala | 34 ++++++++
4 files changed, 90 insertions(+), 48 deletions(-)
create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 66ad38eb7c45b..beb76f2c553c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -48,11 +48,6 @@ private[sql] object JDBCRelation {
* exactly once. The parameters minValue and maxValue are advisory in that
* incorrect values may cause the partitioning to be poor, but no data
* will fail to be represented.
- *
- * @param column - Column name. Must refer to a column of integral type.
- * @param numPartitions - Number of partitions
- * @param minValue - Smallest value of column. Advisory.
- * @param maxValue - Largest value of column. Advisory.
*/
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
@@ -68,12 +63,17 @@ private[sql] object JDBCRelation {
var currentValue: Long = partitioning.lowerBound
var ans = new ArrayBuffer[Partition]()
while (i < numPartitions) {
- val lowerBound = (if (i != 0) s"$column >= $currentValue" else null)
+ val lowerBound = if (i != 0) s"$column >= $currentValue" else null
currentValue += stride
- val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null)
- val whereClause = (if (upperBound == null) lowerBound
- else if (lowerBound == null) upperBound
- else s"$lowerBound AND $upperBound")
+ val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+ val whereClause =
+ if (upperBound == null) {
+ lowerBound
+ } else if (lowerBound == null) {
+ upperBound
+ } else {
+ s"$lowerBound AND $upperBound"
+ }
ans += JDBCPartition(whereClause, i)
i = i + 1
}
@@ -96,8 +96,7 @@ private[sql] class DefaultSource extends RelationProvider {
if (driver != null) Class.forName(driver)
- if (
- partitionColumn != null
+ if (partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
sys.error("Partitioning incompletely specified")
}
@@ -119,7 +118,8 @@ private[sql] class DefaultSource extends RelationProvider {
private[sql] case class JDBCRelation(
url: String,
table: String,
- parts: Array[Partition])(@transient val sqlContext: SQLContext) extends PrunedFilteredScan {
+ parts: Array[Partition])(@transient val sqlContext: SQLContext)
+ extends PrunedFilteredScan {
override val schema = JDBCRDD.resolveTable(url, table)
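To make the partitioning loop above concrete, here is a self-contained sketch with hypothetical values (column "id", lowerBound 0, upperBound 100, numPartitions 4) that mirrors the loop, assuming a simple stride of (upperBound - lowerBound) / numPartitions:

object ColumnPartitionSketch {
  def main(args: Array[String]): Unit = {
    val column = "id"
    val lowerBound = 0L
    val upperBound = 100L
    val numPartitions = 4
    val stride = (upperBound - lowerBound) / numPartitions // 25 for these values
    var currentValue = lowerBound
    for (i <- 0 until numPartitions) {
      val lower = if (i != 0) s"$column >= $currentValue" else null
      currentValue += stride
      val upper = if (i != numPartitions - 1) s"$column < $currentValue" else null
      val whereClause =
        if (upper == null) lower
        else if (lower == null) upper
        else s"$lower AND $upper"
      println(s"partition $i: $whereClause")
    }
    // Prints:
    //   partition 0: id < 25
    //   partition 1: id >= 25 AND id < 50
    //   partition 2: id >= 50 AND id < 75
    //   partition 3: id >= 75
    // Rows outside [0, 100) still fall into the first or last partition, which is why
    // the bounds are only advisory.
  }
}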
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index f828bcdd65c9e..51ff2443f3717 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.json
import java.io.IOException
import org.apache.hadoop.fs.Path
+
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 6487c14b1eb8f..d3d72089c3303 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -234,65 +234,73 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
primitiveType
}
-object ResolvedDataSource {
- def apply(
- sqlContext: SQLContext,
- userSpecifiedSchema: Option[StructType],
- provider: String,
- options: Map[String, String]): ResolvedDataSource = {
+private[sql] object ResolvedDataSource {
+
+ private val builtinSources = Map(
+ "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource],
+ "json" -> classOf[org.apache.spark.sql.json.DefaultSource],
+ "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource]
+ )
+
+ /** Given a provider name, look up the data source class definition. */
+ def lookupDataSource(provider: String): Class[_] = {
+ if (builtinSources.contains(provider)) {
+ return builtinSources(provider)
+ }
+
val loader = Utils.getContextOrSparkClassLoader
- val clazz: Class[_] = try loader.loadClass(provider) catch {
+ try {
+ loader.loadClass(provider)
+ } catch {
case cnf: java.lang.ClassNotFoundException =>
- try loader.loadClass(provider + ".DefaultSource") catch {
+ try {
+ loader.loadClass(provider + ".DefaultSource")
+ } catch {
case cnf: java.lang.ClassNotFoundException =>
sys.error(s"Failed to load class for data source: $provider")
}
}
+ }
+ /** Create a [[ResolvedDataSource]] for reading data in. */
+ def apply(
+ sqlContext: SQLContext,
+ userSpecifiedSchema: Option[StructType],
+ provider: String,
+ options: Map[String, String]): ResolvedDataSource = {
+ val clazz: Class[_] = lookupDataSource(provider)
val relation = userSpecifiedSchema match {
- case Some(schema: StructType) => {
- clazz.newInstance match {
- case dataSource: SchemaRelationProvider =>
- dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
- case dataSource: org.apache.spark.sql.sources.RelationProvider =>
- sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
- }
+ case Some(schema: StructType) => clazz.newInstance() match {
+ case dataSource: SchemaRelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+ case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+ sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
}
- case None => {
- clazz.newInstance match {
- case dataSource: RelationProvider =>
- dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
- case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
- sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
- }
+
+ case None => clazz.newInstance() match {
+ case dataSource: RelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+ case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+ sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
}
}
-
new ResolvedDataSource(clazz, relation)
}
+ /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */
def apply(
sqlContext: SQLContext,
provider: String,
mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
- val loader = Utils.getContextOrSparkClassLoader
- val clazz: Class[_] = try loader.loadClass(provider) catch {
- case cnf: java.lang.ClassNotFoundException =>
- try loader.loadClass(provider + ".DefaultSource") catch {
- case cnf: java.lang.ClassNotFoundException =>
- sys.error(s"Failed to load class for data source: $provider")
- }
- }
-
- val relation = clazz.newInstance match {
+ val clazz: Class[_] = lookupDataSource(provider)
+ val relation = clazz.newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(sqlContext, mode, options, data)
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
}
-
new ResolvedDataSource(clazz, relation)
}
}
@@ -405,6 +413,5 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St
/**
* The exception thrown from the DDL parser.
- * @param message
*/
protected[sql] class DDLException(message: String) extends Exception(message)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
new file mode 100644
index 0000000000000..8331a14c9295c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.scalatest.FunSuite
+
+class ResolvedDataSourceSuite extends FunSuite {
+
+ test("builtin sources") {
+ assert(ResolvedDataSource.lookupDataSource("jdbc") ===
+ classOf[org.apache.spark.sql.jdbc.DefaultSource])
+
+ assert(ResolvedDataSource.lookupDataSource("json") ===
+ classOf[org.apache.spark.sql.json.DefaultSource])
+
+ assert(ResolvedDataSource.lookupDataSource("parquet") ===
+ classOf[org.apache.spark.sql.parquet.DefaultSource])
+ }
+}
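A short sketch of the resolution order that lookupDataSource implements; the provider strings are just illustrations, and since ResolvedDataSource is private[sql] after this patch, the caller has to live inside the org.apache.spark.sql package, as the new test suite does.

package org.apache.spark.sql.sources

object LookupSketch {
  def main(args: Array[String]): Unit = {
    // Short name: resolved through the builtinSources map.
    println(ResolvedDataSource.lookupDataSource("parquet"))
    // Fully qualified class name: loaded directly from the classpath.
    println(ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json.DefaultSource"))
    // Package-style name: falls back to loading "<provider>.DefaultSource".
    println(ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json"))
    // Any other string fails with "Failed to load class for data source: <provider>".
  }
}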
From f86a89a2e081ee4593ce03398c2283fd77daac6e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh