From 9bad062268676aaa66dcbddd1e0ab7f2d7742425 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Jan 2015 16:51:42 -0800 Subject: [PATCH 01/41] [SPARK-5355] make SparkConf thread-safe The SparkConf is not thread-safe, but is accessed by many threads. The getAll() could return parts of the configs if another thread is access it. This PR changes SparkConf.settings to a thread-safe TrieMap. Author: Davies Liu Closes #4143 from davies/safe-conf and squashes the following commits: f8fa1cf [Davies Liu] change to TrieMap a1d769a [Davies Liu] make SparkConf thread-safe --- core/src/main/scala/org/apache/spark/SparkConf.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a0ce107f43b16..f9d4aa4240e9d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap import scala.collection.mutable.{HashMap, LinkedHashSet} import org.apache.spark.serializer.KryoSerializer @@ -46,7 +47,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new HashMap[String, String]() + private[spark] val settings = new TrieMap[String, String]() if (loadDefaults) { // Load any spark.* system properties @@ -177,7 +178,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.clone().toArray + def getAll: Array[(String, String)] = settings.toArray /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { From 27bccc5ea9742ca166f6026033e7771c8a90cc0a Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 21 Jan 2015 17:34:18 -0800 Subject: [PATCH 02/41] [SPARK-5202] [SQL] Add hql variable substitution support https://cwiki.apache.org/confluence/display/Hive/LanguageManual+VariableSubstitution This is a block issue for the CLI user, it impacts the existed hql scripts from Hive. 
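For reference, a minimal usage sketch of the substitution behaviour this patch enables (not part of the patch itself): it assumes a Hive table named `src` already exists, mirroring the new test below, and that a local `HiveContext` can be created.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

// Sketch only: a Hive variable is set with SET and then referenced via
// ${hiveconf:...}, which this patch expands through Hive's VariableSubstitution
// before the query is parsed. Assumes a table named "src" already exists.
object HqlSubstitutionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("hql-substitution").setMaster("local[2]"))
    val hiveContext = new HiveContext(sc)

    hiveContext.sql("set tbl=src")
    hiveContext.sql("SELECT key FROM ${hiveconf:tbl} LIMIT 1").collect()

    // Substitution honors hive.variable.substitute, so it can still be disabled.
    hiveContext.sql("set hive.variable.substitute=false")

    sc.stop()
  }
}
```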
Author: Cheng Hao Closes #4003 from chenghao-intel/substitution and squashes the following commits: bb41fd6 [Cheng Hao] revert the removed the implicit conversion af7c31a [Cheng Hao] add hql variable substitution support --- .../apache/spark/sql/hive/HiveContext.scala | 6 ++++-- .../sql/hive/execution/SQLQuerySuite.scala | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 274f83af5ac03..9d2cfd8e0d669 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} @@ -66,11 +67,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { new this.QueryExecution { val logical = plan } override def sql(sqlText: String): SchemaRDD = { + val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { - super.sql(sqlText) + super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText))) + new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f6bf2dbb5d6e4..7f9f1ac7cd80d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -104,6 +104,24 @@ class SQLQuerySuite extends QueryTest { ) } + test("command substitution") { + sql("set tbl=src") + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + + sql("set hive.variable.substitute=false") // disable the substitution + sql("set tbl2=src") + intercept[Exception] { + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() + } + + sql("set hive.variable.substitute=true") // enable the substitution + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + } + test("ordering not in select") { checkAnswer( sql("SELECT key FROM src ORDER BY value"), From ca7910d6dd7693be2a675a0d6a6fcc9eb0aaeb5d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 21 Jan 2015 21:20:31 -0800 Subject: [PATCH 03/41] [SPARK-3424][MLLIB] cache point distances during k-means|| init This PR ports the following feature implemented in #2634 by derrickburns: * During k-means|| initialization, we should cache costs (squared distances) previously computed. 
It also contains the following optimization: * aggregate sumCosts directly * ran multiple (#runs) k-means++ in parallel I compared the performance locally on mnist-digit. Before this patch: ![before](https://cloud.githubusercontent.com/assets/829644/5845647/93080862-a172-11e4-9a35-044ec711afc4.png) with this patch: ![after](https://cloud.githubusercontent.com/assets/829644/5845653/a47c29e8-a172-11e4-8e9f-08db57fe3502.png) It is clear that each k-means|| iteration takes about the same amount of time with this patch. Authors: Derrick Burns Xiangrui Meng Closes #4144 from mengxr/SPARK-3424-kmeans-parallel and squashes the following commits: 0a875ec [Xiangrui Meng] address comments 4341bb8 [Xiangrui Meng] do not re-compute point distances during k-means|| --- .../spark/mllib/clustering/KMeans.scala | 65 ++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 6b5c934f015ba..fc46da3a93425 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -279,45 +279,80 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { - // Initialize each run's center to a random point + // Initialize empty centers and point costs. + val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) + var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + + // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + + /** Merges new centers to centers. */ + def mergeNewCenters(): Unit = { + var r = 0 + while (r < runs) { + centers(r) ++= newCenters(r) + newCenters(r).clear() + r += 1 + } + } // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers + // to their squared distance from that run's centers. Note that only distances between points + // and new centers are computed in each iteration. 
var step = 0 while (step < initializationSteps) { - val bcCenters = data.context.broadcast(centers) - val sumCosts = data.flatMap { point => - (0 until runs).map { r => - (r, KMeans.pointCost(bcCenters.value(r), point)) - } - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => + val bcNewCenters = data.context.broadcast(newCenters) + val preCosts = costs + costs = data.zip(preCosts).map { case (point, cost) => + Vectors.dense( + Array.tabulate(runs) { r => + math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) + }) + }.cache() + val sumCosts = costs + .aggregate(Vectors.zeros(runs))( + seqOp = (s, v) => { + // s += v + axpy(1.0, v, s) + s + }, + combOp = (s0, s1) => { + // s0 += s1 + axpy(1.0, s1, s0) + s0 + } + ) + preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - points.flatMap { p => + pointsWithCosts.flatMap { case (p, c) => (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) + rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) }.map((_, p)) } }.collect() + mergeNewCenters() chosen.foreach { case (r, p) => - centers(r) += p.toDense + newCenters(r) += p.toDense } step += 1 } + mergeNewCenters() + costs.unpersist(blocking = false) + // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => - (0 until runs).map { r => + Iterator.tabulate(runs) { r => ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) From fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb Mon Sep 17 00:00:00 2001 From: Basin Date: Wed, 21 Jan 2015 23:06:34 -0800 Subject: [PATCH 04/41] [SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317 When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification). I overload the method BoostingStragety.defaultParams(). Author: Basin Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits: 87bab1c [Basin] Docs and Code documentations updated. 3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo). 7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead. d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo 65f96ce [Basin] mllib-ensembles doc modified. e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration. 68cf544 [Basin] mllib-ensembles doc modified. a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration. 
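To make the intended call sites concrete, a small sketch (not part of the patch) showing the existing string-based call next to the new enumeration-based overload; the `treeStrategy`/`numClasses` names follow the configuration classes touched below.

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

// Sketch only: both calls should yield equivalent defaults once this patch is in.
object DefaultParamsSketch {
  def main(args: Array[String]): Unit = {
    val fromString = BoostingStrategy.defaultParams("Classification")    // existing string API
    val fromEnum   = BoostingStrategy.defaultParams(Algo.Classification) // new overload
    println(fromString.treeStrategy.numClasses == fromEnum.treeStrategy.numClasses) // true
  }
}
```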
--- .../tree/configuration/BoostingStrategy.scala | 25 +++++++++++++------ .../mllib/tree/configuration/Strategy.scala | 14 ++++++++--- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..ed8e6a796f8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -68,6 +68,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +84,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..972959885f396 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -173,11 +173,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } From 3027f06b4127ab23a43c5ce8cebf721e3b6766e5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Jan 2015 23:41:44 -0800 Subject: [PATCH 05/41] [SPARK-5147][Streaming] Delete the received data WAL log periodically This is a refactored fix based on jerryshao 's PR #4037 This enabled deletion of old WAL files containing the received block data. 
Improvements over #4037 - Respecting the rememberDuration of all receiver streams. In #4037, if there were two receiver streams with multiple remember durations, the deletion would have delete based on the shortest remember duration, thus deleting data prematurely for the receiver stream with longer remember duration. - Added unit test to test creation of receiver WAL, automatic deletion, and respecting of remember duration. jerryshao I am going to merge this ASAP to make it 1.2.1 Thanks for the initial draft of this PR. Made my job much easier. Author: Tathagata Das Author: jerryshao Closes #4149 from tdas/SPARK-5147 and squashes the following commits: 730798b [Tathagata Das] Added comments. c4cf067 [Tathagata Das] Minor fixes 2579b27 [Tathagata Das] Refactored the fix to make sure that the cleanup respects the remember duration of all the receiver streams 2736fd1 [jerryshao] Delete the old WAL log periodically --- .../apache/spark/streaming/DStreamGraph.scala | 8 + .../dstream/ReceiverInputDStream.scala | 11 -- .../streaming/receiver/ReceiverMessage.scala | 5 +- .../receiver/ReceiverSupervisorImpl.scala | 9 + .../streaming/scheduler/JobGenerator.scala | 11 +- .../scheduler/ReceivedBlockTracker.scala | 1 - .../streaming/scheduler/ReceiverTracker.scala | 18 +- .../spark/streaming/util/HdfsUtils.scala | 2 +- .../spark/streaming/ReceiverSuite.scala | 157 ++++++++++++++---- 9 files changed, 172 insertions(+), 50 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index e59c24adb84af..0e285d6088ec1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } + /** + * Get the maximum remember duration across all the input streams. This is a conservative but + * safe remember duration which can be used to perform cleanup operations. + */ + def getMaxInputStreamRememberDuration(): Duration = { + inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { logDebug("DStreamGraph.writeObject used") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index afd3c4bc4c4fe..8be04314c4285 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } - - /** - * Clear metadata that are older than `rememberDuration` of this DStream. - * This is an internal method that should not be called directly. This - * implementation overrides the default implementation to clear received - * block information. 
- */ - private[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index ab9fa192191aa..7bf3c33319491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming.receiver -/** Messages sent to the NetworkReceiver. */ +import org.apache.spark.streaming.Time + +/** Messages sent to the Receiver. */ private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage +private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index d7229c2b96d0b..716cf2c7f32fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl( case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) + case CleanupOldBlocks(threshTime) => + logDebug("Received delete old batch signal") + cleanupOldBlocks(threshTime) } def ref = self @@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl( /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) + + private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = { + logDebug(s"Cleaning up blocks older then $cleanupThreshTime") + receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 39b66e1130768..d86f852aba97e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -238,13 +238,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { eventActor ! DoCheckpoint(time) } else { + // If checkpointing is not enabled, then delete metadata information about + // received blocks (block data not saved in any case). Otherwise, wait for + // checkpointing of this batch to complete. 
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -252,6 +256,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + + // All the checkpoint information about which batches have been processed, etc have + // been saved to checkpoints, so its safe to delete block metadata and data WAL files + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index c3d9d7b6813d3..ef23b5c79f2e1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -150,7 +150,6 @@ private[streaming] class ReceivedBlockTracker( writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) - log } /** Stop the block tracker. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 8dbb42a86e3bd..4f998869731ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -24,9 +24,8 @@ import scala.language.existentials import akka.actor._ import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - /** Clean up metadata older than the given threshold time */ - def cleanupOldMetadata(cleanupThreshTime: Time) { + /** + * Clean up the data and metadata of blocks and batches that are strictly + * older than the threshold time. Note that this does not + */ + def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) { + // Clean up old block and batch metadata receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) + + // Signal the receivers to delete old block data + if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + logInfo(s"Cleanup old received batch data: $cleanupThreshTime") + receiverInfo.values.flatMap { info => Option(info.actor) } + .foreach { _ ! 
CleanupOldBlocks(cleanupThreshTime) } + } } /** Register a receiver */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 27a28bab83ed5..858ba3c9eb4e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -63,7 +63,7 @@ private[streaming] object HdfsUtils { } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { - // For local file systems, return the raw loca file system, such calls to flush() + // For local file systems, return the raw local file system, such calls to flush() // actually flushes the stream. val fs = path.getFileSystem(conf) fs match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e26c0c6859e57..e8c34a9ee40b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -17,21 +17,26 @@ package org.apache.spark.streaming +import java.io.File import java.nio.ByteBuffer import java.util.concurrent.Semaphore +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor} -import org.scalatest.FunSuite +import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ + /** Testsuite for testing the network receiver behavior */ -class ReceiverSuite extends FunSuite with Timeouts { +class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("receiver life cycle") { @@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts { val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") - println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes) assert( // the first and last block may be incomplete, so we slice them out recordedBlocks.drop(1).dropRight(1).forall { block => @@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts { ) } - /** - * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. + * Test whether write ahead logs are generated by received, + * and automatically cleaned up. The clean up must be aware of the + * remember duration of the input streams. E.g., input streams on which window() + * has been applied must remember the data for longer, and hence corresponding + * WALs should be cleaned later. 
*/ - class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - @volatile var otherThread: Thread = null - @volatile var receiving = false - @volatile var onStartCalled = false - @volatile var onStopCalled = false - - def onStart() { - otherThread = new Thread() { - override def run() { - receiving = true - while(!isStopped()) { - Thread.sleep(10) - } + test("write ahead log - generating and cleaning") { + val sparkConf = new SparkConf() + .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers + .setAppName(framework) + .set("spark.ui.enabled", "true") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val batchDuration = Milliseconds(500) + val tempDirectory = Files.createTempDir() + val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) + val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) + val allLogFiles1 = new mutable.HashSet[String]() + val allLogFiles2 = new mutable.HashSet[String]() + logInfo("Temp checkpoint directory = " + tempDirectory) + + def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = { + (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2)) + } + + def getCurrentLogFiles(logDirectory: File): Seq[String] = { + try { + if (logDirectory.exists()) { + logDirectory1.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + } else { + Seq.empty } + } catch { + case e: Exception => + Seq.empty } - onStartCalled = true - otherThread.start() - } - def onStop() { - onStopCalled = true - otherThread.join() + def printLogFiles(message: String, files: Seq[String]) { + logInfo(s"$message (${files.size} files):\n" + files.mkString("\n")) } - def reset() { - receiving = false - onStartCalled = false - onStopCalled = false + withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => + tempDirectory.deleteOnExit() + val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiverStream1 = ssc.receiverStream(receiver1) + val receiverStream2 = ssc.receiverStream(receiver2) + receiverStream1.register() + receiverStream2.window(batchDuration * 6).register() // 3 second window + ssc.checkpoint(tempDirectory.getAbsolutePath()) + ssc.start() + + // Run until sufficient WAL files have been generated and + // the first WAL files has been deleted + eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) { + val (logFiles1, logFiles2) = getBothCurrentLogFiles() + allLogFiles1 ++= logFiles1 + allLogFiles2 ++= logFiles2 + if (allLogFiles1.size > 0) { + assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head)) + } + if (allLogFiles2.size > 0) { + assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head)) + } + assert(allLogFiles1.size >= 7) + assert(allLogFiles2.size >= 7) + } + ssc.stop(stopSparkContext = true, stopGracefully = true) + + val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted + val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted + val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles() + + printLogFiles("Receiver 0: all", sortedAllLogFiles1) + printLogFiles("Receiver 0: left", leftLogFiles1) + printLogFiles("Receiver 1: all", sortedAllLogFiles2) + printLogFiles("Receiver 1: left", leftLogFiles2) + + // Verify that necessary latest log files are not deleted + // receiverStream1 needs to retain just the last batch = 1 log 
file + // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files + assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains)) + assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains)) } } @@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts { } } +/** + * An implementation of Receiver that is used for testing a receiver's life cycle. + */ +class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false + + def onStart() { + otherThread = new Thread() { + override def run() { + receiving = true + var count = 0 + while(!isStopped()) { + if (sendData) { + store(count) + count += 1 + } + Thread.sleep(10) + } + } + } + onStartCalled = true + otherThread.start() + } + + def onStop() { + onStopCalled = true + otherThread.join() + } + + def reset() { + receiving = false + onStartCalled = false + onStopCalled = false + } +} + From 246111d179a2f3f6b97a5c2b121d8ddbfd1c9aad Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Jan 2015 08:16:35 -0800 Subject: [PATCH 06/41] [SPARK-5365][MLlib] Refactor KMeans to reduce redundant data If a point is selected as new centers for many runs, it would collect many redundant data. This pr refactors it. Author: Liang-Chi Hsieh Closes #4159 from viirya/small_refactor_kmeans and squashes the following commits: 25487e6 [Liang-Chi Hsieh] Refactor codes to reduce redundant data. --- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fc46da3a93425..11633e8242313 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -328,14 +328,15 @@ class KMeans private ( val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) pointsWithCosts.flatMap { case (p, c) => - (0 until runs).filter { r => + val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) - }.map((_, p)) + } + if (rs.length > 0) Some(p, rs) else None } }.collect() mergeNewCenters() - chosen.foreach { case (r, p) => - newCenters(r) += p.toDense + chosen.foreach { case (p, rs) => + rs.foreach(newCenters(_) += p.toDense) } step += 1 } From 820ce03597350257abe0c5c96435c555038e3e6c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 22 Jan 2015 13:49:35 -0600 Subject: [PATCH 07/41] SPARK-5370. [YARN] Remove some unnecessary synchronization in YarnAlloca... ...tor Author: Sandy Ryza Closes #4164 from sryza/sandy-spark-5370 and squashes the following commits: 0c8d736 [Sandy Ryza] SPARK-5370. 
[YARN] Remove some unnecessary synchronization in YarnAllocator --- .../spark/deploy/yarn/YarnAllocator.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4c35b60c57df3..d00f29665a58f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -60,7 +60,6 @@ private[yarn] class YarnAllocator( import YarnAllocator._ - // These two complementary data structures are locked on allocatedHostToContainersMap. // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] @@ -355,20 +354,18 @@ private[yarn] class YarnAllocator( } } - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).get + val containerSet = allocatedHostToContainersMap.get(host).get - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) } + + allocatedContainerToHostMap.remove(containerId) } } } From 3c3fa632e6ba45ce536065aa1145698385301fb2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Jan 2015 21:58:53 -0800 Subject: [PATCH 08/41] [SPARK-5233][Streaming] Fix error replaying of WAL introduced bug Because of lacking of `BlockAllocationEvent` in WAL recovery, the dangled event will mix into the new batch, which will lead to the wrong result. Details can be seen in [SPARK-5233](https://issues.apache.org/jira/browse/SPARK-5233). 
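To illustrate the intended recovery semantics, a purely conceptual, self-contained sketch (none of the names below are the real Spark classes): re-allocating an already-allocated batch time during WAL replay is treated as a no-op rather than an error, and blocks that were added but never allocated before the failure are handed to the next batch that gets allocated.

```scala
import scala.collection.mutable

// Toy model only; BlockInfo and Tracker are illustrative stand-ins, not Spark classes.
object WalReplaySketch {
  case class BlockInfo(id: Int)

  class Tracker {
    private val unallocated = mutable.Queue.empty[BlockInfo]
    private val allocated = mutable.HashMap.empty[Long, Seq[BlockInfo]]

    def addBlock(b: BlockInfo): Unit = unallocated.enqueue(b)

    def allocateBlocksToBatch(batchTime: Long): Unit = {
      if (!allocated.contains(batchTime)) {
        val blocks = unallocated.toList
        unallocated.clear()
        allocated(batchTime) = blocks
      } // already allocated (e.g. replayed during recovery): keep the original assignment
    }

    def blocksOf(batchTime: Long): Seq[BlockInfo] = allocated.getOrElse(batchTime, Nil)
  }

  def main(args: Array[String]): Unit = {
    val tracker = new Tracker
    tracker.addBlock(BlockInfo(1))
    tracker.allocateBlocksToBatch(1000L)
    tracker.addBlock(BlockInfo(2))        // dangling: added but not yet allocated at failure time
    tracker.allocateBlocksToBatch(1000L)  // replayed allocation: no error, no re-assignment
    tracker.allocateBlocksToBatch(2000L)  // the dangling block goes to the batch it belongs to
    assert(tracker.blocksOf(1000L) == List(BlockInfo(1)))
    assert(tracker.blocksOf(2000L) == List(BlockInfo(2)))
  }
}
```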
Author: jerryshao Closes #4032 from jerryshao/SPARK-5233 and squashes the following commits: f0b0c0b [jerryshao] Further address the comments a237c75 [jerryshao] Address the comments e356258 [jerryshao] Fix bug in unit test 558bdc3 [jerryshao] Correctly replay the WAL log when recovering from failure --- .../examples/streaming/KafkaWordCount.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 18 +++++++++++------ .../scheduler/ReceivedBlockTracker.scala | 12 ++++++++--- .../streaming/ReceivedBlockTrackerSuite.scala | 20 +++++++++---------- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 2adc63f7ff30e..387c0e421334b 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -76,7 +76,7 @@ object KafkaWordCountProducer { val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args - // Zookeper connection properties + // Zookeeper connection properties val props = new Properties() props.put("metadata.broker.list", brokers) props.put("serializer.class", "kafka.serializer.StringEncoder") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index d86f852aba97e..8632c94349bf9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -17,12 +17,14 @@ package org.apache.spark.streaming.scheduler -import akka.actor.{ActorRef, ActorSystem, Props, Actor} -import org.apache.spark.{SparkException, SparkEnv, Logging} -import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} -import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} import scala.util.{Failure, Success, Try} +import akka.actor.{ActorRef, Props, Actor} + +import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} +import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer} + /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent @@ -206,9 +208,13 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) - timesToReschedule.foreach(time => + timesToReschedule.foreach { time => + // Allocate the related blocks when recovering from failure, because some blocks that were + // added but not allocated, are dangling in the queue after recovering, we have to allocate + // those blocks to the next batch, which is the batch they were supposed to go. 
+ jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch jobScheduler.submitJobSet(JobSet(time, graph.generateJobs(time))) - ) + } // Restart the timer timer.start(restartTime.milliseconds) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index ef23b5c79f2e1..e19ac939f9ac5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -67,7 +67,7 @@ private[streaming] class ReceivedBlockTracker( extends Logging { private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] - + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] private val logManagerOption = createLogManager() @@ -107,8 +107,14 @@ private[streaming] class ReceivedBlockTracker( lastAllocatedBatchTime = batchTime allocatedBlocks } else { - throw new SparkException(s"Unexpected allocation of blocks, " + - s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + // This situation occurs when: + // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, + // possibly processed batch job or half-processed batch job need to be processed again, + // so the batchTime will be equal to lastAllocatedBatchTime. + // 2. Slow checkpointing makes recovered batch time older than WAL recovered + // lastAllocatedBatchTime. + // This situation will only occurs in recovery time. + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index de7e9d624bf6b..fbb7b0bfebafc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -82,15 +82,15 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.allocateBlocksToBatch(2) receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty - // Verify that batch 2 cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(2) - } + // Verify that older batches have no operation on batch allocation, + // will return the same blocks as previously allocated. 
+ receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos - // Verify that older batches cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(1) - } + blockInfos.map(receivedBlockTracker.addBlock) + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } test("block addition, block to batch allocation and cleanup with write ahead log") { @@ -186,14 +186,14 @@ class ReceivedBlockTrackerSuite tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 } - + test("enabling write ahead log but not setting checkpoint dir") { conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") intercept[SparkException] { createTracker(setCheckpointDir = false) } } - + test("setting checkpoint dir but not enabling write ahead log") { // When WAL config is not set, log manager should not be enabled val tracker1 = createTracker(setCheckpointDir = true) From e0f7fb7f9f497b34d42f9ba147197cf9ffc51607 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Jan 2015 22:04:21 -0800 Subject: [PATCH 09/41] [SPARK-5315][Streaming] Fix reduceByWindow Java API not work bug `reduceByWindow` for Java API is actually not Java compatible, change to make it Java compatible. Current solution is to deprecate the old one and add a new API, but since old API actually is not correct, so is keeping the old one meaningful? just to keep the binary compatible? Also even adding new API still need to add to Mima exclusion, I'm not sure to change the API, or deprecate the old API and add a new one, which is the best solution? 
Author: jerryshao Closes #4104 from jerryshao/SPARK-5315 and squashes the following commits: 5bc8987 [jerryshao] Address the comment c7aa1b4 [jerryshao] Deprecate the old one to keep binary compatible 8e9dc67 [jerryshao] Fix JavaDStream reduceByWindow signature error --- project/MimaExcludes.scala | 4 ++++ .../streaming/api/java/JavaDStreamLike.scala | 20 +++++++++++++++++++ .../apache/spark/streaming/JavaAPISuite.java | 20 +++++++++++++++++-- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 127973b658190..bc5d81f12d746 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -90,6 +90,10 @@ object MimaExcludes { // SPARK-5297 Java FileStream do not work with custom key/values ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") ) case v if v.startsWith("1.2") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index e0542eda1383f..c382a12f4d099 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -211,7 +211,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval + * @deprecated As this API is not Java compatible. */ + @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0") def reduceByWindow( reduceFunc: (T, T) => T, windowDuration: Duration, @@ -220,6 +222,24 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: JFunction2[T, T, T], + windowDuration: Duration, + slideDuration: Duration + ): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) + } + /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. 
However, the reduction is done incrementally diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index d92e7fe899a09..d4c40745658c2 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -306,7 +306,17 @@ public void testReduce() { @SuppressWarnings("unchecked") @Test - public void testReduceByWindow() { + public void testReduceByWindowWithInverse() { + testReduceByWindow(true); + } + + @SuppressWarnings("unchecked") + @Test + public void testReduceByWindowWithoutInverse() { + testReduceByWindow(false); + } + + private void testReduceByWindow(boolean withInverse) { List> inputData = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), @@ -319,8 +329,14 @@ public void testReduceByWindow() { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + JavaDStream reducedWindowed = null; + if (withInverse) { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + } else { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new Duration(2000), new Duration(1000)); + } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); From ea74365b7c5a3ac29cae9ba66f140f1fa5e8d312 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 22 Jan 2015 22:09:13 -0800 Subject: [PATCH 10/41] [SPARK-3541][MLLIB] New ALS implementation with improved storage This PR adds a new ALS implementation to `spark.ml` using the pipeline API, which should be able to scale to billions of ratings. Compared with the ALS under `spark.mllib`, the new implementation 1. uses the same algorithm, 2. uses float type for ratings, 3. uses primitive arrays to avoid GC, 4. sorts and compresses ratings on each block so that we can solve least squares subproblems one by one using only one normal equation instance. The following figure shows performance comparison on copies of the Amazon Reviews dataset using a 16-node (m3.2xlarge) EC2 cluster (the same setup as in http://databricks.com/blog/2014/07/23/scalable-collaborative-filtering-with-spark-mllib.html): ![als-wip](https://cloud.githubusercontent.com/assets/829644/5659447/4c4ff8e0-96c7-11e4-87a9-73c1c63d07f3.png) I keep the `spark.mllib`'s ALS untouched for easy comparison. If the new implementation works well, I'm going to match the features of the ALS under `spark.mllib` and then make it a wrapper of the new implementation, in a separate PR. TODO: - [X] Add unit tests for implicit preferences. 
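As background for the "one normal equation instance" remark above, a tiny conceptual sketch (not the code in this patch) of accumulating A^T A and A^T b for one user's least-squares subproblem and then solving the small k x k system. Here k is fixed at 2 so the solve can be written inline with Cramer's rule; the real implementation uses a Cholesky solve and reuses a single mutable buffer per block.

```scala
// Conceptual sketch of the normal-equation accumulation used per subproblem.
object NormalEquationSketch {
  val k = 2

  final class NormalEquation {
    val ata = Array.ofDim[Double](k * k) // dense k x k, row-major, filled symmetrically
    val atb = Array.ofDim[Double](k)
    def add(a: Array[Double], b: Double): this.type = {
      var i = 0
      while (i < k) {
        atb(i) += a(i) * b
        var j = 0
        while (j < k) { ata(i * k + j) += a(i) * a(j); j += 1 }
        i += 1
      }
      this
    }
    def reset(): Unit = { java.util.Arrays.fill(ata, 0.0); java.util.Arrays.fill(atb, 0.0) }
  }

  /** Solve the 2x2 system (A^T A + lambda * I) x = A^T b via Cramer's rule. */
  def solve(ne: NormalEquation, lambda: Double): Array[Double] = {
    val a = ne.ata(0) + lambda; val b = ne.ata(1)
    val c = ne.ata(2);          val d = ne.ata(3) + lambda
    val det = a * d - b * c
    Array((d * ne.atb(0) - b * ne.atb(1)) / det, (a * ne.atb(1) - c * ne.atb(0)) / det)
  }

  def main(args: Array[String]): Unit = {
    val ne = new NormalEquation
    // item factor vectors this user rated, with the corresponding ratings
    ne.add(Array(1.0, 0.5), 4.0).add(Array(0.2, 1.0), 1.0)
    println(solve(ne, lambda = 0.1).mkString(", "))
    ne.reset() // the same instance is reused for the next user's subproblem
  }
}
```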
Author: Xiangrui Meng Closes #3720 from mengxr/SPARK-3541 and squashes the following commits: 1b9e852 [Xiangrui Meng] fix compile 5129be9 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3541 dd0d0e8 [Xiangrui Meng] simplify test code c627de3 [Xiangrui Meng] add tests for implicit feedback b84f41c [Xiangrui Meng] address comments a76da7b [Xiangrui Meng] update ALS tests 2a8deb3 [Xiangrui Meng] add some ALS tests 857e876 [Xiangrui Meng] add tests for rating block and encoded block d3c1ac4 [Xiangrui Meng] rename some classes for better code readability add more doc and comments 213d163 [Xiangrui Meng] org imports 771baf3 [Xiangrui Meng] chol doc update ca9ad9d [Xiangrui Meng] add unit tests for chol b4fd17c [Xiangrui Meng] add unit tests for NormalEquation d0f99d3 [Xiangrui Meng] add tests for LocalIndexEncoder 80b8e61 [Xiangrui Meng] fix imports 4937fd4 [Xiangrui Meng] update ALS example 56c253c [Xiangrui Meng] rename product to item bce8692 [Xiangrui Meng] doc for parameters and project the output columns 3f2d81a [Xiangrui Meng] add doc 1efaecf [Xiangrui Meng] add example code 8ae86b5 [Xiangrui Meng] add a working copy of the new ALS implementation --- .../spark/examples/ml/MovieLensALS.scala | 174 ++++ .../apache/spark/ml/recommendation/ALS.scala | 973 ++++++++++++++++++ .../spark/mllib/recommendation/ALS.scala | 2 +- .../spark/ml/recommendation/ALSSuite.scala | 435 ++++++++ .../spark/mllib/recommendation/ALSSuite.scala | 2 +- 5 files changed, 1584 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala new file mode 100644 index 0000000000000..cf62772b92651 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.recommendation.ALS +import org.apache.spark.sql.{Row, SQLContext} + +/** + * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). 
+ * Run with + * {{{ + * bin/run-example ml.MovieLensALS + * }}} + */ +object MovieLensALS { + + case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) + + object Rating { + def parseRating(str: String): Rating = { + val fields = str.split("::") + assert(fields.size == 4) + Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) + } + } + + case class Movie(movieId: Int, title: String, genres: Seq[String]) + + object Movie { + def parseMovie(str: String): Movie = { + val fields = str.split("::") + assert(fields.size == 3) + Movie(fields(0).toInt, fields(1), fields(2).split("|")) + } + } + + case class Params( + ratings: String = null, + movies: String = null, + maxIter: Int = 10, + regParam: Double = 0.1, + rank: Int = 10, + numBlocks: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("MovieLensALS") { + head("MovieLensALS: an example app for ALS on MovieLens data.") + opt[String]("ratings") + .required() + .text("path to a MovieLens dataset of ratings") + .action((x, c) => c.copy(ratings = x)) + opt[String]("movies") + .required() + .text("path to a MovieLens dataset of movies") + .action((x, c) => c.copy(movies = x)) + opt[Int]("rank") + .text(s"rank, default: ${defaultParams.rank}}") + .action((x, c) => c.copy(rank = x)) + opt[Int]("maxIter") + .text(s"max number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Int]("numBlocks") + .text(s"number of blocks, default: ${defaultParams.numBlocks}") + .action((x, c) => c.copy(numBlocks = x)) + note( + """ + |Example command line to run this app: + | + | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ + | examples/target/scala-*/spark-examples-*.jar \ + | --rank 10 --maxIter 15 --regParam 0.1 \ + | --movies path/to/movielens/movies.dat \ + | --ratings path/to/movielens/ratings.dat + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"MovieLensALS with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() + + val numRatings = ratings.count() + val numUsers = ratings.map(_.userId).distinct().count() + val numMovies = ratings.map(_.movieId).distinct().count() + + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + val splits = ratings.randomSplit(Array(0.8, 0.2), 0L) + val training = splits(0).cache() + val test = splits(1).cache() + + val numTraining = training.count() + val numTest = test.count() + println(s"Training: $numTraining, test: $numTest.") + + ratings.unpersist(blocking = false) + + val als = new ALS() + .setUserCol("userId") + .setItemCol("movieId") + .setRank(params.rank) + .setMaxIter(params.maxIter) + .setRegParam(params.regParam) + .setNumBlocks(params.numBlocks) + + val model = als.fit(training) + + val predictions = model.transform(test).cache() + + // Evaluate the model. + // TODO: Create an evaluator to compute RMSE. 
+ val mse = predictions.select('rating, 'prediction) + .flatMap { case Row(rating: Float, prediction: Float) => + val err = rating.toDouble - prediction + val err2 = err * err + if (err2.isNaN) { + None + } else { + Some(err2) + } + }.mean() + val rmse = math.sqrt(mse) + println(s"Test RMSE = $rmse.") + + // Inspect false positives. + predictions.registerTempTable("prediction") + sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") + sqlContext.sql( + """ + |SELECT userId, prediction.movieId, title, rating, prediction + | FROM prediction JOIN movie ON prediction.movieId = movie.movieId + | WHERE rating <= 1 AND prediction >= 4 + | LIMIT 100 + """.stripMargin) + .collect() + .foreach(println) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala new file mode 100644 index 0000000000000..2d89e76a4c8b2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -0,0 +1,973 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import java.{util => ju} + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.netlib.util.intW + +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} +import org.apache.spark.util.random.XORShiftRandom + +/** + * Common params for ALS. + */ +private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam + with HasPredictionCol { + + /** Param for rank of the matrix factorization. */ + val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + def getRank: Int = get(rank) + + /** Param for number of user blocks. */ + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + def getNumUserBlocks: Int = get(numUserBlocks) + + /** Param for number of item blocks. */ + val numItemBlocks = + new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + def getNumItemBlocks: Int = get(numItemBlocks) + + /** Param to decide whether to use implicit preference. 
*/ + val implicitPrefs = + new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + def getImplicitPrefs: Boolean = get(implicitPrefs) + + /** Param for the alpha parameter in the implicit preference formulation. */ + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + def getAlpha: Double = get(alpha) + + /** Param for the column name for user ids. */ + val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + def getUserCol: String = get(userCol) + + /** Param for the column name for item ids. */ + val itemCol = + new Param[String](this, "itemCol", "column name for item ids", Some("item")) + def getItemCol: String = get(itemCol) + + /** Param for the column name for ratings. */ + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + def getRatingCol: String = get(ratingCol) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @param paramMap extra params + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + assert(schema(map(userCol)).dataType == IntegerType) + assert(schema(map(itemCol)).dataType== IntegerType) + val ratingType = schema(map(ratingCol)).dataType + assert(ratingType == FloatType || ratingType == DoubleType) + val predictionColName = map(predictionCol) + assert(!schema.fieldNames.contains(predictionColName), + s"Prediction column $predictionColName already exists.") + val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false) + StructType(newFields) + } +} + +/** + * Model fitted by ALS. + */ +class ALSModel private[ml] ( + override val parent: ALS, + override val fittingParamMap: ParamMap, + k: Int, + userFactors: RDD[(Int, Array[Float])], + itemFactors: RDD[(Int, Array[Float])]) + extends Model[ALSModel] with ALSParams { + + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + import dataset.sqlContext._ + import org.apache.spark.ml.recommendation.ALSModel.Factor + val map = this.paramMap ++ paramMap + // TODO: Add DSL to simplify the code here. 
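As a rough guide to the SchemaRDD join below, here is a conceptual equivalent written against plain RDDs. This is a sketch only, not part of this patch; it assumes the pair-RDD implicits are in scope and reuses the netlib `blas` handle imported at the top of this file.
{{{
  // Conceptual equivalent of the transform join: left-outer-join the (user, item)
  // pairs with both factor RDDs and take a dot product; a missing user or item
  // factor yields Float.NaN, matching the DSL version below.
  def predictAll(
      userItemPairs: RDD[(Int, Int)],
      userFactors: RDD[(Int, Array[Float])],
      itemFactors: RDD[(Int, Array[Float])],
      k: Int): RDD[((Int, Int), Float)] = {
    userItemPairs
      .leftOuterJoin(userFactors)
      .map { case (userId, (itemId, userF)) => (itemId, (userId, userF)) }
      .leftOuterJoin(itemFactors)
      .map { case (itemId, ((userId, userF), itemF)) =>
        val score = (userF, itemF) match {
          case (Some(u), Some(i)) => blas.sdot(k, u, 1, i, 1)
          case _ => Float.NaN // user or item unseen during training
        }
        ((userId, itemId), score)
      }
  }
}}}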
+ val instanceTable = s"instance_$uid" + val userTable = s"user_$uid" + val itemTable = s"item_$uid" + val instances = dataset.as(Symbol(instanceTable)) + val users = userFactors.map { case (id, features) => + Factor(id, features) + }.as(Symbol(userTable)) + val items = itemFactors.map { case (id, features) => + Factor(id, features) + }.as(Symbol(itemTable)) + val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { + if (userFeatures != null && itemFeatures != null) { + blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + } else { + Float.NaN + } + } + val inputColumns = dataset.schema.fieldNames + val prediction = + predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol) + val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction + instances + .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr)) + .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr)) + .select(outputColumns: _*) + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private object ALSModel { + /** Case class to convert factors to SchemaRDDs */ + private case class Factor(id: Int, features: Seq[Float]) +} + +/** + * Alternating Least Squares (ALS) matrix factorization. + * + * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, + * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices. + * The general approach is iterative. During each iteration, one of the factor matrices is held + * constant, while the other is solved for using least squares. The newly-solved factor matrix is + * then held constant while solving for the other factor matrix. + * + * This is a blocked implementation of the ALS factorization algorithm that groups the two sets + * of factors (referred to as "users" and "products") into blocks and reduces communication by only + * sending one copy of each user vector to each product block on each iteration, and only for the + * product blocks that need that user's feature vector. This is achieved by pre-computing some + * information about the ratings matrix to determine the "out-links" of each user (which blocks of + * products it will contribute to) and "in-link" information for each product (which of the feature + * vectors it receives from each user block it will depend on). This allows us to send only an + * array of feature vectors between each user block and product block, and have the product block + * find the users' ratings and update the products based on these messages. + * + * For implicit preference data, the algorithm used is based on + * "Collaborative Filtering for Implicit Feedback Datasets", available at + * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here. + * + * Essentially instead of finding the low-rank approximations to the rating matrix `R`, + * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of + * indicated user + * preferences rather than explicit ratings given to items. 
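A small sketch (not part of this patch) of the implicit-preference mapping described above: each raw rating r is turned into a binary preference and a confidence weight, which is what `NormalEquation.addImplicit` below encodes.
{{{
  // preference p is 1 for r > 0 and 0 for r <= 0; confidence grows with |r|
  def toPreference(r: Float, alpha: Double): (Float, Double) =
    (if (r > 0) 1.0f else 0.0f, 1.0 + alpha * math.abs(r))
}}}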
+ */ +class ALS extends Estimator[ALSModel] with ALSParams { + + import org.apache.spark.ml.recommendation.ALS.Rating + + def setRank(value: Int): this.type = set(rank, value) + def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + def setAlpha(value: Double): this.type = set(alpha, value) + def setUserCol(value: String): this.type = set(userCol, value) + def setItemCol(value: String): this.type = set(itemCol, value) + def setRatingCol(value: String): this.type = set(ratingCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setRegParam(value: Double): this.type = set(regParam, value) + + /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + def setNumBlocks(value: Int): this.type = { + setNumUserBlocks(value) + setNumItemBlocks(value) + this + } + + setMaxIter(20) + setRegParam(1.0) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = { + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val ratings = + dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType)) + .map { row => + new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } + val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), + numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), + maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), + alpha = map(alpha)) + val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) + Params.inheritValues(map, this, model) + model + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private[recommendation] object ALS extends Logging { + + /** Rating class for better code readability. */ + private[recommendation] case class Rating(user: Int, item: Int, rating: Float) + + /** Cholesky solver for least square problems. */ + private[recommendation] class CholeskySolver { + + private val upper = "U" + private val info = new intW(0) + + /** + * Solves a least squares problem with L2 regularization: + * + * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * + * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) + * @param lambda regularization constant, which will be scaled by n + * @return the solution x + */ + def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + val k = ne.k + // Add scaled lambda to the diagonals of AtA. + val scaledlambda = lambda * ne.n + var i = 0 + var j = 2 + while (i < ne.triK) { + ne.ata(i) += scaledlambda + i += j + j += 1 + } + lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dppsv returned $code.") + val x = new Array[Float](k) + i = 0 + while (i < k) { + x(i) = ne.atb(i).toFloat + i += 1 + } + ne.reset() + x + } + } + + /** Representing a normal equation (ALS' subproblem). */ + private[recommendation] class NormalEquation(val k: Int) extends Serializable { + + /** Number of entries in the upper triangular part of a k-by-k matrix. */ + val triK = k * (k + 1) / 2 + /** A^T^ * A */ + val ata = new Array[Double](triK) + /** A^T^ * b */ + val atb = new Array[Double](k) + /** Number of observations. 
*/ + var n = 0 + + private val da = new Array[Double](k) + private val upper = "U" + + private def copyToDouble(a: Array[Float]): Unit = { + var i = 0 + while (i < k) { + da(i) = a(i) + i += 1 + } + } + + /** Adds an observation. */ + def add(a: Array[Float], b: Float): this.type = { + require(a.size == k) + copyToDouble(a) + blas.dspr(upper, k, 1.0, da, 1, ata) + blas.daxpy(k, b.toDouble, da, 1, atb, 1) + n += 1 + this + } + + /** + * Adds an observation with implicit feedback. Note that this does not increment the counter. + */ + def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + require(a.size == k) + // Extension to the original paper to handle b < 0. confidence is a function of |b| instead + // so that it is never negative. + val confidence = 1.0 + alpha * math.abs(b) + copyToDouble(a) + blas.dspr(upper, k, confidence - 1.0, da, 1, ata) + // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. + if (b > 0) { + blas.daxpy(k, confidence, da, 1, atb, 1) + } + this + } + + /** Merges another normal equation object. */ + def merge(other: NormalEquation): this.type = { + require(other.k == k) + blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1) + blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1) + n += other.n + this + } + + /** Resets everything to zero, which should be called after each solve. */ + def reset(): Unit = { + ju.Arrays.fill(ata, 0.0) + ju.Arrays.fill(atb, 0.0) + n = 0 + } + } + + /** + * Implementation of the ALS algorithm. + */ + private def train( + ratings: RDD[Rating], + rank: Int = 10, + numUserBlocks: Int = 10, + numItemBlocks: Int = 10, + maxIter: Int = 10, + regParam: Double = 1.0, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = { + val userPart = new HashPartitioner(numUserBlocks) + val itemPart = new HashPartitioner(numItemBlocks) + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + val blockRatings = partitionRatings(ratings, userPart, itemPart).cache() + val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart) + // materialize blockRatings and user blocks + userOutBlocks.count() + val swappedBlockRatings = blockRatings.map { + case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => + ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) + } + val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart) + // materialize item blocks + itemOutBlocks.count() + var userFactors = initialize(userInBlocks, rank) + var itemFactors = initialize(itemInBlocks, rank) + if (implicitPrefs) { + for (iter <- 1 to maxIter) { + userFactors.setName(s"userFactors-$iter").persist() + val previousItemFactors = itemFactors + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder, implicitPrefs, alpha) + previousItemFactors.unpersist() + itemFactors.setName(s"itemFactors-$iter").persist() + val previousUserFactors = userFactors + userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder, implicitPrefs, alpha) + previousUserFactors.unpersist() + } + } else { + for (iter <- 0 until maxIter) { + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder) + userFactors = computeFactors(itemFactors, 
itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder) + } + } + val userIdAndFactors = userInBlocks + .mapValues(_.srcIds) + .join(userFactors) + .values + .setName("userFactors") + .cache() + userIdAndFactors.count() + itemFactors.unpersist() + val itemIdAndFactors = itemInBlocks + .mapValues(_.srcIds) + .join(itemFactors) + .values + .setName("itemFactors") + .cache() + itemIdAndFactors.count() + userInBlocks.unpersist() + userOutBlocks.unpersist() + itemInBlocks.unpersist() + itemOutBlocks.unpersist() + blockRatings.unpersist() + val userOutput = userIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + (userOutput, itemOutput) + } + + /** + * Factor block that stores factors (Array[Float]) in an Array. + */ + private type FactorBlock = Array[Array[Float]] + + /** + * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to + * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the + * src factors in this block to send to dst block 0. + */ + private type OutBlock = Array[Array[Int]] + + /** + * In-link block for computing src (user/item) factors. This includes the original src IDs + * of the elements within this block as well as encoded dst (item/user) indices and corresponding + * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original + * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. + * For example, if we have an in-link record + * + * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * + * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which + * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * + * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can + * compute src factors one after another using only one normal equation instance. + * + * @param srcIds src ids (ordered) + * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and + * ratings are associated with srcIds(i). + * @param dstEncodedIndices encoded dst indices + * @param ratings ratings + * + * @see [[LocalIndexEncoder]] + */ + private[recommendation] case class InBlock( + srcIds: Array[Int], + dstPtrs: Array[Int], + dstEncodedIndices: Array[Int], + ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = ratings.size + + require(dstEncodedIndices.size == size) + require(dstPtrs.size == srcIds.size + 1) + } + + /** + * Initializes factors randomly given the in-link blocks. + * + * @param inBlocks in-link blocks + * @param rank rank + * @return initialized factor blocks + */ + private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = { + // Choose a unit vector uniformly at random from the unit sphere, but from the + // "first quadrant" where all elements are nonnegative. This can be done by choosing + // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. + // This appears to create factorizations that have a slightly better reconstruction + // (<1%) compared picking elements uniformly at random in [0,1]. 
+ inBlocks.map { case (srcBlockId, inBlock) => + val random = new XORShiftRandom(srcBlockId) + val factors = Array.fill(inBlock.srcIds.size) { + val factor = Array.fill(rank)(random.nextGaussian().toFloat) + val nrm = blas.snrm2(rank, factor, 1) + blas.sscal(rank, 1.0f / nrm, factor, 1) + factor + } + (srcBlockId, factors) + } + } + + /** + * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays. + */ + private[recommendation] + case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = srcIds.size + require(dstIds.size == size) + require(ratings.size == size) + } + + /** + * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing. + */ + private[recommendation] class RatingBlockBuilder extends Serializable { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstIds = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + var size = 0 + + /** Adds a rating. */ + def add(r: Rating): this.type = { + size += 1 + srcIds += r.user + dstIds += r.item + ratings += r.rating + this + } + + /** Merges another [[RatingBlockBuilder]]. */ + def merge(other: RatingBlock): this.type = { + size += other.srcIds.size + srcIds ++= other.srcIds + dstIds ++= other.dstIds + ratings ++= other.ratings + this + } + + /** Builds a [[RatingBlock]]. */ + def build(): RatingBlock = { + RatingBlock(srcIds.result(), dstIds.result(), ratings.result()) + } + } + + /** + * Partitions raw ratings into blocks. + * + * @param ratings raw ratings + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * + * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock) + */ + private def partitionRatings( + ratings: RDD[Rating], + srcPart: Partitioner, + dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = { + + /* The implementation produces the same result as the following but generates less objects. + + ratings.map { r => + ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + }.aggregateByKey(new RatingBlockBuilder)( + seqOp = (b, r) => b.add(r), + combOp = (b0, b1) => b0.merge(b1.build())) + .mapValues(_.build()) + */ + + val numPartitions = srcPart.numPartitions * dstPart.numPartitions + ratings.mapPartitions { iter => + val builders = Array.fill(numPartitions)(new RatingBlockBuilder) + iter.flatMap { r => + val srcBlockId = srcPart.getPartition(r.user) + val dstBlockId = dstPart.getPartition(r.item) + val idx = srcBlockId + srcPart.numPartitions * dstBlockId + val builder = builders(idx) + builder.add(r) + if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k + builders(idx) = new RatingBlockBuilder + Iterator.single(((srcBlockId, dstBlockId), builder.build())) + } else { + Iterator.empty + } + } ++ { + builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) => + val srcBlockId = idx % srcPart.numPartitions + val dstBlockId = idx / srcPart.numPartitions + ((srcBlockId, dstBlockId), block.build()) + } + } + }.groupByKey().mapValues { blocks => + val builder = new RatingBlockBuilder + blocks.foreach(builder.merge) + builder.build() + }.setName("ratingBlocks") + } + + /** + * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples. 
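A standalone sketch (not part of this patch) of the flat-index arithmetic used by `partitionRatings` above, where a (srcBlockId, dstBlockId) pair is flattened into a single builder index and recovered with `%` and `/`:
{{{
  val numSrcBlocks = 3
  val numDstBlocks = 2
  for (srcBlockId <- 0 until numSrcBlocks; dstBlockId <- 0 until numDstBlocks) {
    val idx = srcBlockId + numSrcBlocks * dstBlockId // in [0, numSrcBlocks * numDstBlocks)
    assert(idx % numSrcBlocks == srcBlockId)
    assert(idx / numSrcBlocks == dstBlockId)
  }
}}}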
+ * @param encoder encoder for dst indices + */ + private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstEncodedIndices = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + + /** + * Adds a dst block of (srcId, dstLocalIndex, rating) tuples. + * + * @param dstBlockId dst block ID + * @param srcIds original src IDs + * @param dstLocalIndices dst local indices + * @param ratings ratings + */ + def add( + dstBlockId: Int, + srcIds: Array[Int], + dstLocalIndices: Array[Int], + ratings: Array[Float]): this.type = { + val sz = srcIds.size + require(dstLocalIndices.size == sz) + require(ratings.size == sz) + this.srcIds ++= srcIds + this.ratings ++= ratings + var j = 0 + while (j < sz) { + this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j)) + j += 1 + } + this + } + + /** Builds a [[UncompressedInBlock]]. */ + def build(): UncompressedInBlock = { + new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result()) + } + } + + /** + * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays. + */ + private[recommendation] class UncompressedInBlock( + val srcIds: Array[Int], + val dstEncodedIndices: Array[Int], + val ratings: Array[Float]) { + + /** Size the of block. */ + def size: Int = srcIds.size + + /** + * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a + * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. + */ + def compress(): InBlock = { + val sz = size + assert(sz > 0, "Empty in-link block should not exist.") + sort() + val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int] + val dstCountsBuilder = mutable.ArrayBuilder.make[Int] + var preSrcId = srcIds(0) + uniqueSrcIdsBuilder += preSrcId + var curCount = 1 + var i = 1 + var j = 0 + while (i < sz) { + val srcId = srcIds(i) + if (srcId != preSrcId) { + uniqueSrcIdsBuilder += srcId + dstCountsBuilder += curCount + preSrcId = srcId + j += 1 + curCount = 0 + } + curCount += 1 + i += 1 + } + dstCountsBuilder += curCount + val uniqueSrcIds = uniqueSrcIdsBuilder.result() + val numUniqueSrdIds = uniqueSrcIds.size + val dstCounts = dstCountsBuilder.result() + val dstPtrs = new Array[Int](numUniqueSrdIds + 1) + var sum = 0 + i = 0 + while (i < numUniqueSrdIds) { + sum += dstCounts(i) + i += 1 + dstPtrs(i) = sum + } + InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings) + } + + private def sort(): Unit = { + val sz = size + // Since there might be interleaved log messages, we insert a unique id for easy pairing. + val sortId = Utils.random.nextInt() + logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)") + val start = System.nanoTime() + val sorter = new Sorter(new UncompressedInBlockSort) + sorter.sort(this, 0, size, Ordering[IntWrapper]) + val duration = (System.nanoTime() - start) / 1e9 + logDebug(s"Sorting took $duration seconds. (sortId = $sortId)") + } + } + + /** + * A wrapper that holds a primitive integer key. + * + * @see [[UncompressedInBlockSort]] + */ + private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { + override def compare(that: IntWrapper): Int = { + key.compare(that.key) + } + } + + /** + * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]]. 
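To make the CSC layout produced by `compress()` above concrete, here is a sketch (not part of this patch, helper name hypothetical) of how a compressed InBlock is traversed: the entries belonging to `srcIds(i)` occupy positions [dstPtrs(i), dstPtrs(i + 1)), which is the same loop `computeFactors` runs later.
{{{
  // Hypothetical helper: visit every (srcId, dstEncodedIndex, rating) triple.
  def foreachEntry(block: InBlock)(f: (Int, Int, Float) => Unit): Unit = {
    var i = 0
    while (i < block.srcIds.length) {
      var j = block.dstPtrs(i)
      while (j < block.dstPtrs(i + 1)) {
        f(block.srcIds(i), block.dstEncodedIndices(j), block.ratings(j))
        j += 1
      }
      i += 1
    }
  }
}}}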
+ */ + private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] { + + override def newKey(): IntWrapper = new IntWrapper() + + override def getKey( + data: UncompressedInBlock, + pos: Int, + reuse: IntWrapper): IntWrapper = { + if (reuse == null) { + new IntWrapper(data.srcIds(pos)) + } else { + reuse.key = data.srcIds(pos) + reuse + } + } + + override def getKey( + data: UncompressedInBlock, + pos: Int): IntWrapper = { + getKey(data, pos, null) + } + + private def swapElements[@specialized(Int, Float) T]( + data: Array[T], + pos0: Int, + pos1: Int): Unit = { + val tmp = data(pos0) + data(pos0) = data(pos1) + data(pos1) = tmp + } + + override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = { + swapElements(data.srcIds, pos0, pos1) + swapElements(data.dstEncodedIndices, pos0, pos1) + swapElements(data.ratings, pos0, pos1) + } + + override def copyRange( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int, + length: Int): Unit = { + System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length) + System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length) + System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length) + } + + override def allocate(length: Int): UncompressedInBlock = { + new UncompressedInBlock( + new Array[Int](length), new Array[Int](length), new Array[Float](length)) + } + + override def copyElement( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int): Unit = { + dst.srcIds(dstPos) = src.srcIds(srcPos) + dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos) + dst.ratings(dstPos) = src.ratings(srcPos) + } + } + + /** + * Creates in-blocks and out-blocks from rating blocks. 
+ * @param prefix prefix for in/out-block names + * @param ratingBlocks rating blocks + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * @return (in-blocks, out-blocks) + */ + private def makeBlocks( + prefix: String, + ratingBlocks: RDD[((Int, Int), RatingBlock)], + srcPart: Partitioner, + dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = { + val inBlocks = ratingBlocks.map { + case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => + // The implementation is a faster version of + // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap + val start = System.nanoTime() + val dstIdSet = new OpenHashSet[Int](1 << 20) + dstIds.foreach(dstIdSet.add) + val sortedDstIds = new Array[Int](dstIdSet.size) + var i = 0 + var pos = dstIdSet.nextPos(0) + while (pos != -1) { + sortedDstIds(i) = dstIdSet.getValue(pos) + pos = dstIdSet.nextPos(pos + 1) + i += 1 + } + assert(i == dstIdSet.size) + ju.Arrays.sort(sortedDstIds) + val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size) + i = 0 + while (i < sortedDstIds.size) { + dstIdToLocalIndex.update(sortedDstIds(i), i) + i += 1 + } + logDebug( + "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") + val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) + (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) + }.groupByKey(new HashPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks").cache() + val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => + val encoder = new LocalIndexEncoder(dstPart.numPartitions) + val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) + var i = 0 + val seen = new Array[Boolean](dstPart.numPartitions) + while (i < srcIds.size) { + var j = dstPtrs(i) + ju.Arrays.fill(seen, false) + while (j < dstPtrs(i + 1)) { + val dstBlockId = encoder.blockId(dstEncodedIndices(j)) + if (!seen(dstBlockId)) { + activeIds(dstBlockId) += i // add the local index in this out-block + seen(dstBlockId) = true + } + j += 1 + } + i += 1 + } + activeIds.map { x => + x.result() + } + }.setName(prefix + "OutBlocks").cache() + (inBlocks, outBlocks) + } + + /** + * Compute dst factors by constructing and solving least square problems. 
+ * + * @param srcFactorBlocks src factors + * @param srcOutBlocks src out-blocks + * @param dstInBlocks dst in-blocks + * @param rank rank + * @param regParam regularization constant + * @param srcEncoder encoder for src local indices + * @param implicitPrefs whether to use implicit preference + * @param alpha the alpha constant in the implicit preference formulation + * + * @return dst factors + */ + private def computeFactors( + srcFactorBlocks: RDD[(Int, FactorBlock)], + srcOutBlocks: RDD[(Int, OutBlock)], + dstInBlocks: RDD[(Int, InBlock)], + rank: Int, + regParam: Double, + srcEncoder: LocalIndexEncoder, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): RDD[(Int, FactorBlock)] = { + val numSrcBlocks = srcFactorBlocks.partitions.size + val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None + val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { + case (srcBlockId, (srcOutBlock, srcFactors)) => + srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) => + (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) + } + } + val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size)) + dstInBlocks.join(merged).mapValues { + case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => + val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) + srcFactors.foreach { case (srcBlockId, factors) => + sortedSrcFactors(srcBlockId) = factors + } + val dstFactors = new Array[Array[Float]](dstIds.size) + var j = 0 + val ls = new NormalEquation(rank) + val solver = new CholeskySolver // TODO: add NNLS solver + while (j < dstIds.size) { + ls.reset() + if (implicitPrefs) { + ls.merge(YtY.get) + } + var i = srcPtrs(j) + while (i < srcPtrs(j + 1)) { + val encoded = srcEncodedIndices(i) + val blockId = srcEncoder.blockId(encoded) + val localIndex = srcEncoder.localIndex(encoded) + val srcFactor = sortedSrcFactors(blockId)(localIndex) + val rating = ratings(i) + if (implicitPrefs) { + ls.addImplicit(srcFactor, rating, alpha) + } else { + ls.add(srcFactor, rating) + } + i += 1 + } + dstFactors(j) = solver.solve(ls, regParam) + j += 1 + } + dstFactors + } + } + + /** + * Computes the Gramian matrix of user or item factors, which is only used in implicit preference. + * Caching of the input factors is handled in [[ALS#train]]. + */ + private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { + factorBlocks.values.aggregate(new NormalEquation(rank))( + seqOp = (ne, factors) => { + factors.foreach(ne.add(_, 0.0f)) + ne + }, + combOp = (ne1, ne2) => ne1.merge(ne2)) + } + + /** + * Encoder for storing (blockId, localIndex) into a single integer. + * + * We use the leading bits (including the sign bit) to store the block id and the rest to store + * the local index. This is based on the assumption that users/items are approximately evenly + * partitioned. With this assumption, we should be able to encode two billion distinct values. + * + * @param numBlocks number of blocks + */ + private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable { + + require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.") + + private[this] final val numLocalIndexBits = + math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31) + private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1 + + /** Encodes a (blockId, localIndex) into a single integer. 
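A worked example of the encoding (not part of this patch): with numBlocks = 10, Integer.numberOfLeadingZeros(9) is 28, so the low 28 bits carry the local index and the remaining high bits, including the sign bit, carry the block id.
{{{
  val encoder = new LocalIndexEncoder(10)
  val encoded = encoder.encode(3, 42)        // (3 << 28) | 42
  assert(encoder.blockId(encoded) == 3)
  assert(encoder.localIndex(encoded) == 42)
}}}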
*/ + def encode(blockId: Int, localIndex: Int): Int = { + require(blockId < numBlocks) + require((localIndex & ~localIndexMask) == 0) + (blockId << numLocalIndexBits) | localIndex + } + + /** Gets the block id from an encoded index. */ + @inline + def blockId(encoded: Int): Int = { + encoded >>> numLocalIndexBits + } + + /** Gets the local index from an encoded index. */ + @inline + def localIndex(encoded: Int): Int = { + encoded & localIndexMask + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index bee951a2e5e26..5f84677be238d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double) * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if - * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of * indicated user * preferences rather than explicit ratings given to items. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala new file mode 100644 index 0000000000000..cdd4db1b5b7dc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.recommendation + +import java.util.Random + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.scalatest.FunSuite + +import org.apache.spark.Logging +import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} + +class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { + + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("LocalIndexEncoder") { + val random = new Random + for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { + val encoder = new LocalIndexEncoder(numBlocks) + val maxLocalIndex = Int.MaxValue / numBlocks + val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++ + Seq((0, 0), (numBlocks - 1, maxLocalIndex)) + tests.foreach { case (blockId, localIndex) => + val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex." + val encoded = encoder.encode(blockId, localIndex) + assert(encoder.blockId(encoded) === blockId, err) + assert(encoder.localIndex(encoded) === localIndex, err) + } + } + } + + test("normal equation construction with explict feedback") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 3.0f) + .add(Array(4.0f, 5.0f), 6.0f) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 2) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5") + // b = np.matrix("3; 6") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + + val ne1 = new NormalEquation(2) + .add(Array(7.0f, 8.0f), 9.0f) + ne0.merge(ne1) + assert(ne0.n === 3) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5; 7 8") + // b = np.matrix("3; 6; 9") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f), 2.0f) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + } + intercept[IllegalArgumentException] { + val ne2 = new NormalEquation(3) + ne0.merge(ne2) + } + + ne0.reset() + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + } + + test("normal equation construction with implicit feedback") { + val k = 2 + val alpha = 0.5 + val ne0 = new NormalEquation(k) + .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) + .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) + .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 0) // addImplicit doesn't increase the count. 
+ // NumPy code that computes the expected values: + // alpha = 0.5 + // A = np.matrix("-5 -4; -2 -1; 1 2") + // b = np.matrix("-3; 0; 3") + // b1 = b > 0 + // c = 1.0 + alpha * np.abs(b) + // C = np.diag(c.A1) + // I = np.eye(3) + // ata = A.transpose() * (C - I) * A + // atb = A.transpose() * C * b1 + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) + } + + test("CholeskySolver") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 4.0f) + .add(Array(1.0f, 3.0f), 9.0f) + .add(Array(1.0f, 4.0f), 16.0f) + val ne1 = new NormalEquation(k) + .merge(ne0) + + val chol = new CholeskySolver + val x0 = chol.solve(ne0, 0.0).map(_.toDouble) + // NumPy code that computes the expected solution: + // A = np.matrix("1 2; 1 3; 1 4") + // b = b = np.matrix("3; 6") + // x0 = np.linalg.lstsq(A, b)[0] + assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) + + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + + val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + // NumPy code that computes the expected solution, where lambda is scaled by n: + // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) + } + + test("RatingBlockBuilder") { + val emptyBuilder = new RatingBlockBuilder() + assert(emptyBuilder.size === 0) + val emptyBlock = emptyBuilder.build() + assert(emptyBlock.srcIds.isEmpty) + assert(emptyBlock.dstIds.isEmpty) + assert(emptyBlock.ratings.isEmpty) + + val builder0 = new RatingBlockBuilder() + .add(Rating(0, 1, 2.0f)) + .add(Rating(3, 4, 5.0f)) + assert(builder0.size === 2) + val builder1 = new RatingBlockBuilder() + .add(Rating(6, 7, 8.0f)) + .merge(builder0.build()) + assert(builder1.size === 3) + val block = builder1.build() + val ratings = Seq.tabulate(block.size) { i => + (block.srcIds(i), block.dstIds(i), block.ratings(i)) + }.toSet + assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f))) + } + + test("UncompressedInBlock") { + val encoder = new LocalIndexEncoder(10) + val uncompressed = new UncompressedInBlockBuilder(encoder) + .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f)) + .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f)) + .build() + assert(uncompressed.size === 5) + val records = Seq.tabulate(uncompressed.size) { i => + val dstEncodedIndex = uncompressed.dstEncodedIndices(i) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i)) + }.toSet + val expected = + Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f)) + assert(records === expected) + + val compressed = uncompressed.compress() + assert(compressed.size === 5) + assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3)) + assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) + var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] + var i = 0 + while (i < compressed.srcIds.size) { + var j = compressed.dstPtrs(i) + while (j < compressed.dstPtrs(i + 1)) { + val dstEncodedIndex = compressed.dstEncodedIndices(j) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j))) + j += 1 + } + i += 1 + } + 
assert(decompressed.toSet === expected) + } + + /** + * Generates an explicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genExplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val x = random.nextDouble() + if (x < totalFraction) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates an implicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genImplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + // The assumption of the implicit feedback model is that unobserved ratings are more likely to + // be negatives. + val positiveFraction = 0.8 + val negativeFraction = 1.0 - positiveFraction + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + val threshold = if (rating > 0) positiveFraction else negativeFraction + val observed = random.nextDouble() < threshold + if (observed) { + val x = random.nextDouble() + if (x < totalFraction) { + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + } + logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates random user/item factors, with i.i.d. values drawn from U(a, b). 
+ * @param size number of users/items + * @param rank number of features + * @param random random number generator + * @param a min value of the support (default: -1) + * @param b max value of the support (default: 1) + * @return a sequence of (ID, factors) pairs + */ + private def genFactors( + size: Int, + rank: Int, + random: Random, + a: Float = -1.0f, + b: Float = 1.0f): Seq[(Int, Array[Float])] = { + require(size > 0 && size < Int.MaxValue / 3) + require(b > a) + val ids = mutable.Set.empty[Int] + while (ids.size < size) { + ids += random.nextInt() + } + val width = b - a + ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width))) + } + + /** + * Test ALS using the given training/test splits and parameters. + * @param training training dataset + * @param test test dataset + * @param rank rank of the matrix factorization + * @param maxIter max number of iterations + * @param regParam regularization constant + * @param implicitPrefs whether to use implicit preference + * @param numUserBlocks number of user blocks + * @param numItemBlocks number of item blocks + * @param targetRMSE target test RMSE + */ + def testALS( + training: RDD[Rating], + test: RDD[Rating], + rank: Int, + maxIter: Int, + regParam: Double, + implicitPrefs: Boolean = false, + numUserBlocks: Int = 2, + numItemBlocks: Int = 3, + targetRMSE: Double = 0.05): Unit = { + val sqlContext = this.sqlContext + import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute} + val als = new ALS() + .setRank(rank) + .setRegParam(regParam) + .setImplicitPrefs(implicitPrefs) + .setNumUserBlocks(numUserBlocks) + .setNumItemBlocks(numItemBlocks) + val alpha = als.getAlpha + val model = als.fit(training) + val predictions = model.transform(test) + .select('rating, 'prediction) + .map { case Row(rating: Float, prediction: Float) => + (rating.toDouble, prediction.toDouble) + } + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE + // with the confidence scores as weights. 
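Spelled out, the weighted RMSE below is sqrt(sum(c * err^2) / sum(c)) with confidence c = 1 + alpha * |rating| and both rating and prediction clipped to [0, 1]. A local, non-distributed sketch (not part of this patch, helper name hypothetical):
{{{
  def weightedRmse(pairs: Seq[(Double, Double)], alpha: Double): Double = {
    def clip01(x: Double) = math.max(math.min(x, 1.0), 0.0)
    val (w, wse) = pairs.foldLeft((0.0, 0.0)) { case ((w, wse), (rating, prediction)) =>
      val c = 1.0 + alpha * math.abs(rating)
      val err = clip01(prediction) - clip01(rating)
      (w + c, wse + c * err * err)
    }
    math.sqrt(wse / w)
  }
}}}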
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val mse = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + }.mean() + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) + } + + test("exact rank-1 matrix") { + val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1) + testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001) + testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001) + } + + test("approximate rank-1 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01) + testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02) + testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02) + } + + test("approximate rank-2 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03) + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03) + } + + test("different block settings") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) { + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03, + numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks) + } + } + + test("more blocks than ratings") { + val (training, test) = + genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002, + numItemBlocks = 5, numUserBlocks = 5) + } + + test("implicit feedback") { + val (training, test) = + genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true, + targetRMSE = 0.3) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index f3b7bfda788fa..e9fc37e000526 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk prediciton + * @param bulkPredict flag to test bulk predicition * @param negativeWeights whether the generated data can contain negative values * @param numUserBlocks number of user blocks to partition users into * @param numProductBlocks number of product blocks to partition products into From cef1f092a628ac20709857b4388bb10e0b5143b0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 23 
Jan 2015 17:53:15 -0800 Subject: [PATCH 11/41] [SPARK-5063] More helpful error messages for several invalid operations This patch adds more helpful error messages for invalid programs that define nested RDDs, broadcast RDDs, perform actions inside of transformations (e.g. calling `count()` from inside of `map()`), and call certain methods on stopped SparkContexts. Currently, these invalid programs lead to confusing NullPointerExceptions at runtime and have been a major source of questions on the mailing list and StackOverflow. In a few cases, I chose to log warnings instead of throwing exceptions in order to avoid any chance that this patch breaks programs that worked "by accident" in earlier Spark releases (e.g. programs that define nested RDDs but never run any jobs with them). In SparkContext, the new `assertNotStopped()` method is used to check whether methods are being invoked on a stopped SparkContext. In some cases, user programs will not crash in spite of calling methods on stopped SparkContexts, so I've only added `assertNotStopped()` calls to methods that always throw exceptions when called on stopped contexts (e.g. by dereferencing a null `dagScheduler` pointer). Author: Josh Rosen Closes #3884 from JoshRosen/SPARK-5063 and squashes the following commits: a38774b [Josh Rosen] Fix spelling typo a943e00 [Josh Rosen] Convert two exceptions into warnings in order to avoid breaking user programs in some edge-cases. 2d0d7f7 [Josh Rosen] Fix test to reflect 1.2.1 compatibility 3f0ea0c [Josh Rosen] Revert two unintentional formatting changes 8e5da69 [Josh Rosen] Remove assertNotStopped() calls for methods that were sometimes safe to call on stopped SC's in Spark 1.2 8cff41a [Josh Rosen] IllegalStateException fix 6ef68d0 [Josh Rosen] Fix Python line length issues. 9f6a0b8 [Josh Rosen] Add improved error messages to PySpark. 13afd0f [Josh Rosen] SparkException -> IllegalStateException 8d404f3 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-5063 b39e041 [Josh Rosen] Fix BroadcastSuite test which broadcasted an RDD 99cc09f [Josh Rosen] Guard against calling methods on stopped SparkContexts. 34833e8 [Josh Rosen] Add more descriptive error message. 57cc8a1 [Josh Rosen] Add error message when directly broadcasting RDD. 15b2e6b [Josh Rosen] [SPARK-5063] Useful error messages for nested RDDs and actions inside of transformations --- .../scala/org/apache/spark/SparkContext.scala | 62 +++++++++++++++---- .../main/scala/org/apache/spark/rdd/RDD.scala | 19 +++++- .../spark/broadcast/BroadcastSuite.scala | 12 +++- .../scala/org/apache/spark/rdd/RDDSuite.scala | 40 ++++++++++++ python/pyspark/context.py | 8 +++ python/pyspark/rdd.py | 11 ++++ 6 files changed, 138 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a354ed4d1486..8175d175b13df 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() + @volatile private var stopped: Boolean = false + + private def assertNotStopped(): Unit = { + if (stopped) { + throw new IllegalStateException("Cannot call methods on a stopped SparkContext") + } + } + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). 
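To illustrate the class of invalid programs this patch targets, a sketch (not part of this patch, assuming an active SparkContext `sc`):
{{{
  val rdd1 = sc.parallelize(1 to 10)
  val rdd2 = sc.parallelize(1 to 10)
  // Invalid: rdd2 (and the driver-side SparkContext behind it) cannot be used
  // inside a task. Before this patch this typically surfaced as a confusing
  // NullPointerException on the executors; afterwards the error message explains
  // that RDD transformations and actions cannot be nested.
  rdd1.map(x => rdd2.count() + x).collect()
}}}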
@@ -525,6 +533,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * modified collection. Pass a copy of the argument to avoid this. */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } @@ -540,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -549,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Hadoop-supported file system URI, and return it as an RDD of Strings. */ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minPartitions).map(pair => pair._2.toString).setName(path) } @@ -582,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -627,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -651,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) : RDD[Array[Byte]] = { + assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, classOf[FixedLengthBinaryInputFormat], @@ -684,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) @@ -703,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. 
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -782,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + assertNotStopped() val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -802,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() new NewHadoopRDD(this, fClass, kClass, vClass, conf) } @@ -817,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int ): RDD[(K, V)] = { + assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } @@ -828,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V] - ): RDD[(K, V)] = + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) + } /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -858,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { + assertNotStopped() val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] @@ -879,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions ): RDD[T] = { + assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } @@ -954,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * The variable will be sent to each cluster only once. */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { + assertNotStopped() + if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have created RDD broadcast variables but not used them: + logWarning("Can not directly broadcast RDDs; instead, call collect() and " + + "broadcast the result (see SPARK-5063)") + } val bc = env.broadcastManager.newBroadcast[T](value, isLocal) val callSite = getCallSite logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) @@ -1046,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * memory available for caching. 
*/ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + assertNotStopped() env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.host + ":" + blockManagerId.port, mem) } @@ -1058,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + assertNotStopped() val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) @@ -1075,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { + assertNotStopped() env.blockManager.master.getStorageStatus } @@ -1084,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getAllPools: Seq[Schedulable] = { + assertNotStopped() // TODO(xiajunluan): We should take nested pools into account taskScheduler.rootPool.schedulableQueue.toSeq } @@ -1094,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getPoolForName(pool: String): Option[Schedulable] = { + assertNotStopped() Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) } @@ -1101,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return current scheduling mode */ def getSchedulingMode: SchedulingMode.SchedulingMode = { + assertNotStopped() taskScheduler.schedulingMode } @@ -1206,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { postApplicationEnd() ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { + if (!stopped) { + stopped = true env.metricsSystem.report() metadataCleaner.cancel() env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() + dagScheduler.stop() + dagScheduler = null taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -1289,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (dagScheduler == null) { - throw new SparkException("SparkContext has been shutdown") + if (stopped) { + throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite val cleanedFunc = clean(func) @@ -1377,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { + assertNotStopped() val callSite = getCallSite logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime @@ -1399,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit, resultFunc: => R): SimpleFutureAction[R] = { + assertNotStopped() val cleanF = clean(processPartition) val callSite = getCallSite val waiter = dagScheduler.submitJob( @@ -1417,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * for more information. 
*/ def cancelJobGroup(groupId: String) { + assertNotStopped() dagScheduler.cancelJobGroup(groupId) } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs() { + assertNotStopped() dagScheduler.cancelAllJobs() } @@ -1468,7 +1505,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getCheckpointDir = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ - def defaultParallelism: Int = taskScheduler.defaultParallelism + def defaultParallelism: Int = { + assertNotStopped() + taskScheduler.defaultParallelism + } /** Default min number of partitions for Hadoop RDDs when not given by user */ @deprecated("use defaultMinPartitions", "1.0.0") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 97012c7033f9f..ab7410a1f7f99 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * on RDD internals. */ abstract class RDD[T: ClassTag]( - @transient private var sc: SparkContext, + @transient private var _sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { + if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have defined nested RDDs without running jobs with them. + logWarning("Spark does not support nested RDDs (see SPARK-5063)") + } + + private def sc: SparkContext = { + if (_sc == null) { + throw new SparkException( + "RDD transformations and actions can only be invoked by the driver, not inside of other " + + "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " + + "the values transformation and count action cannot be performed inside of the rdd1.map " + + "transformation. For more information, see SPARK-5063.") + } + _sc + } + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index b0a70f012f1f3..af3272692d7a1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { testPackage.runCallSiteTest(sc) } + test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") { + sc = new SparkContext("local", "test") + sc.stop() + val thrown = intercept[IllegalStateException] { + sc.broadcast(Seq(1, 2, 3)) + } + assert(thrown.getMessage.toLowerCase.contains("stopped")) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). 
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - val broadcast = sc.broadcast(rdd) + val broadcast = sc.broadcast(Array(1, 2, 3, 4)) broadcast.destroy() val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 381ee2d45630f..e33b4bbbb8e4c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -927,4 +927,44 @@ class RDDSuite extends FunSuite with SharedSparkContext { mutableDependencies += dep } } + + test("nested RDDs are not supported (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator } + nestedRDD.count() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("actions cannot be performed inside of transformations (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + rdd.map(x => x * rdd2.count).collect() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("cannot run actions after SparkContext has been stopped (SPARK-5063)") { + val existingRDD = sc.parallelize(1 to 100) + sc.stop() + val thrown = intercept[IllegalStateException] { + existingRDD.count() + } + assert(thrown.getMessage.contains("shutdown")) + } + + test("cannot call methods on a stopped SparkContext (SPARK-5063)") { + sc.stop() + def assertFails(block: => Any): Unit = { + val thrown = intercept[IllegalStateException] { + block + } + assert(thrown.getMessage.contains("stopped")) + } + assertFails { sc.parallelize(1 to 100) } + assertFails { sc.textFile("/nonexistent-path") } + } } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 64f6a3ca6bf4c..568e21f3803bf 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -229,6 +229,14 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + def __getnewargs__(self): + # This method is called when attempting to pickle SparkContext, which is always an error: + raise Exception( + "It appears that you are attempting to reference SparkContext from a broadcast " + "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "not in code that it run on workers. For more information, see SPARK-5063." + ) + def __enter__(self): """ Enable 'with SparkContext(...) as sc: app(sc)' syntax. diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4977400ac1c05..f4cfe4845dc20 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -141,6 +141,17 @@ def id(self): def __repr__(self): return self._jrdd.toString() + def __getnewargs__(self): + # This method is called when attempting to pickle an RDD, which is always an error: + raise Exception( + "It appears that you are attempting to broadcast an RDD or reference an RDD from an " + "action or transformation. 
RDD transformations and actions can only be invoked by the " + "driver, not inside of other transformations; for example, " + "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values " + "transformation and count action cannot be performed inside of the rdd1.map " + "transformation. For more information, see SPARK-5063." + ) + @property def context(self): """ From e224dbb011789297cd6c6ba095f702c042869ed6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 23 Jan 2015 19:25:15 -0800 Subject: [PATCH 12/41] [SPARK-5351][GraphX] Do not use Partitioner.defaultPartitioner as a partitioner of EdgeRDDImp... If the value of 'spark.default.parallelism' does not match the number of partitoins in EdgePartition(EdgeRDDImpl), the following error occurs in ReplicatedVertexView.scala:72; object GraphTest extends Logging { def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): VertexRDD[Int] = { graph.aggregateMessages( ctx => { ctx.sendToSrc(1) ctx.sendToDst(2) }, _ + _) } } val g = GraphLoader.edgeListFile(sc, "graph.txt") val rdd = GraphTest.run(g) java.lang.IllegalArgumentException: Can't zip RDDs with unequal numbers of partitions at org.apache.spark.rdd.ZippedPartitionsBaseRDD.getPartitions(ZippedPartitionsRDD.scala:57) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:206) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:204) at scala.Option.getOrElse(Option.scala:120) at org.apache.spark.rdd.RDD.partitions(RDD.scala:204) at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:32) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:206) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:204) at scala.Option.getOrElse(Option.scala:120) at org.apache.spark.rdd.RDD.partitions(RDD.scala:204) at org.apache.spark.ShuffleDependency.(Dependency.scala:82) at org.apache.spark.rdd.ShuffledRDD.getDependencies(ShuffledRDD.scala:80) at org.apache.spark.rdd.RDD$$anonfun$dependencies$2.apply(RDD.scala:193) at org.apache.spark.rdd.RDD$$anonfun$dependencies$2.apply(RDD.scala:191) ... Author: Takeshi Yamamuro Closes #4136 from maropu/EdgePartitionBugFix and squashes the following commits: 0cd8942 [Ankur Dave] Use more concise getOrElse aad4a2c [Ankur Dave] Add unit test for non-default number of edge partitions 0a2f32b [Takeshi Yamamuro] Do not use Partitioner.defaultPartitioner as a partitioner of EdgeRDDImpl --- .../spark/graphx/impl/EdgeRDDImpl.scala | 4 ++-- .../org/apache/spark/graphx/GraphSuite.scala | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 897c7ee12a436..f1550ac2e18ad 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( * partitioner that allows co-partitioning with `partitionsRDD`. 
*/ override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 9da0064104fb6..ed9876b8dc21c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("non-default number of edge partitions") { + val n = 10 + val defaultParallelism = 3 + val numEdgePartitions = 4 + assert(defaultParallelism != numEdgePartitions) + val conf = new org.apache.spark.SparkConf() + .set("spark.default.parallelism", defaultParallelism.toString) + val sc = new SparkContext("local", "test", conf) + try { + val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), + numEdgePartitions) + val graph = Graph.fromEdgeTuples(edges, 1) + val neighborAttrSums = graph.mapReduceTriplets[Int]( + et => Iterator((et.dstId, et.srcAttr)), _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + } finally { + sc.stop() + } + } + } From 09e09c548e7722fca1cdc89bd37de2cee58f4ce9 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 23 Jan 2015 23:34:11 -0800 Subject: [PATCH 13/41] [SPARK-5058] Part 2. Typos and broken URL - Also fixed java link Author: Jongyoul Lee Closes #4172 from jongyoul/SPARK-FIXDOC and squashes the following commits: 6be03e5 [Jongyoul Lee] [SPARK-5058] Part 2. Typos and broken URL - Also fixed java link --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 0e38fe2144e9f..77c0abbbacbd0 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). 
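For readers following the Kafka integration guide touched by the patch above, a minimal Scala sketch of the receiver-based stream that the linked word-count example demonstrates (the ZooKeeper quorum, consumer group and topic below are placeholder values, not taken from the patch):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

val sparkConf = new SparkConf().setAppName("KafkaWordCountSketch")
val ssc = new StreamingContext(sparkConf, Seconds(2))
// topic name -> number of receiver threads; purely illustrative values
val topics = Map("my-topic" -> 1)
val lines = KafkaUtils.createStream(ssc, "zkhost:2181", "my-consumer-group", topics).map(_._2)
lines.flatMap(_.split(" ")).map(word => (word, 1)).reduceByKey(_ + _).print()
ssc.start()
ssc.awaitTermination()
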
From 0d1e67ee9b29b51bccfc8a319afe9f9b4581afc8 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 24 Jan 2015 11:00:35 -0800 Subject: [PATCH 14/41] [SPARK-5214][Test] Add a test to demonstrate EventLoop can be stopped in the event thread Author: zsxwing Closes #4174 from zsxwing/SPARK-5214-unittest and squashes the following commits: 443e564 [zsxwing] Change the check interval to 5ms 7aaa2d7 [zsxwing] Add a test to demonstrate EventLoop can be stopped in the event thread --- .../apache/spark/util/EventLoopSuite.scala | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 10541f878476c..1026cb2aa7cae 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -41,7 +41,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() (1 to 100).foreach(eventLoop.post) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert((1 to 100) === buffer.toSeq) } eventLoop.stop() @@ -76,7 +76,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) } eventLoop.stop() @@ -98,7 +98,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) assert(eventLoop.isActive) } @@ -153,7 +153,7 @@ class EventLoopSuite extends FunSuite with Timeouts { }.start() } - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(threadNum * eventsFromEachThread === receivedEventsCount) } eventLoop.stop() @@ -185,4 +185,22 @@ class EventLoopSuite extends FunSuite with Timeouts { } assert(false === eventLoop.isActive) } + + test("EventLoop: stop in eventThread") { + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + } } From d22ca1e921d792e49600c4c484a846e739c340ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 25 Jan 2015 00:24:59 -0800 Subject: [PATCH 15/41] Closes #4157 From 412a58e118ef083ea1d1d6daccd9c531852baf53 Mon Sep 17 00:00:00 2001 From: Idan Zalzberg Date: Sun, 25 Jan 2015 11:28:05 -0800 Subject: [PATCH 16/41] Add comment about defaultMinPartitions Added a comment about using math.min for choosing default partition count Author: Idan Zalzberg Closes #4102 from idanz/patch-2 and squashes the following commits: 50e9d58 [Idan Zalzberg] Update SparkContext.scala --- core/src/main/scala/org/apache/spark/SparkContext.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8175d175b13df..4c4ee04cc515e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1514,7 +1514,11 @@ class SparkContext(config: SparkConf) extends 
Logging with ExecutorAllocationCli @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** Default min number of partitions for Hadoop RDDs when not given by user */ + /** + * Default min number of partitions for Hadoop RDDs when not given by user + * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. + * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 + */ def defaultMinPartitions: Int = math.min(defaultParallelism, 2) private val nextShuffleId = new AtomicInteger(0) From 2d9887bae52a1b4d01765c28dfe333396338bfe6 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 25 Jan 2015 14:17:59 -0800 Subject: [PATCH 17/41] [SPARK-5401] set executor ID before creating MetricsSystem Author: Ryan Williams Closes #4194 from ryan-williams/metrics and squashes the following commits: 7c5a33f [Ryan Williams] set executor ID before creating MetricsSystem --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 4 ++++ .../main/scala/org/apache/spark/metrics/MetricsSystem.scala | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33f..1264a8126153b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -326,6 +326,10 @@ object SparkEnv extends Logging { // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { + // We need to set the executor ID before the MetricsSystem is created because sources and + // sinks specified in the metrics configuration file will want to incorporate this executor's + // ID into the metrics they report. + conf.set("spark.executor.id", executorId) val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) ms.start() ms diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 45633e3de01dd..83e8eb71260eb 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -130,8 +130,8 @@ private[spark] class MetricsSystem private ( if (appId.isDefined && executorId.isDefined) { MetricRegistry.name(appId.get, executorId.get, source.sourceName) } else { - // Only Driver and Executor are set spark.app.id and spark.executor.id. - // For instance, Master and Worker are not related to a specific application. + // Only Driver and Executor set spark.app.id and spark.executor.id. + // Other instance types, e.g. Master and Worker, are not related to a specific application. val warningMsg = s"Using default name $defaultName for source because %s is not set." 
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } From aea25482c370fbcf712a464501605bc16ee4ed5d Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Sun, 25 Jan 2015 14:20:02 -0800 Subject: [PATCH 18/41] [SPARK-5402] log executor ID at executor-construction time also rename "slaveHostname" to "executorHostname" Author: Ryan Williams Closes #4195 from ryan-williams/exec and squashes the following commits: e60a7bb [Ryan Williams] log executor ID at executor-construction time --- .../scala/org/apache/spark/executor/Executor.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42566d1a14093..d8c2e41a7c715 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -41,11 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} */ private[spark] class Executor( executorId: String, - slaveHostname: String, + executorHostname: String, env: SparkEnv, isLocal: Boolean = false) extends Logging { + + logInfo(s"Starting executor ID $executorId on host $executorHostname") + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -58,12 +61,12 @@ private[spark] class Executor( @volatile private var isStopped = false // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) + assert (0 == Utils.parseHostPort(executorHostname)._2) // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) + Utils.setCustomHostname(executorHostname) if (!isLocal) { // Setup an uncaught exception handler for non-local mode. From c586b45dd25b50be7f195df2ce91b307e1ed71a9 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 15:08:05 -0800 Subject: [PATCH 19/41] SPARK-3852 [DOCS] Document spark.driver.extra* configs As per the JIRA. I copied the `spark.executor.extra*` text, but removed info that appears to be specific to the `executor` config and not `driver`. 
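For orientation, a minimal sketch of how an application might supply the driver-side properties this patch documents (the option strings and paths are placeholders, not values from the patch; depending on deploy mode they may need to be provided via spark-defaults.conf or spark-submit flags instead, since the driver JVM can already be running by the time application code executes):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("DriverExtrasSketch")
  // extra JVM options passed to the driver process
  .set("spark.driver.extraJavaOptions", "-XX:+PrintGCDetails")
  // extra entries prepended to the driver's classpath
  .set("spark.driver.extraClassPath", "/opt/libs/extra.jar")
  // extra native library search path for the driver
  .set("spark.driver.extraLibraryPath", "/opt/native")
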
Author: Sean Owen Closes #4185 from srowen/SPARK-3852 and squashes the following commits: f60a8a1 [Sean Owen] Document spark.driver.extra* configs --- docs/configuration.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index efbab4085317a..7c5b6d011cfd3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -197,6 +197,27 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment + + + + + + + + + + + + + + + From 383425ab70aef4d4961f60309309ba5cb663db5f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 15:11:57 -0800 Subject: [PATCH 20/41] SPARK-3782 [CORE] Direct use of log4j in AkkaUtils interferes with certain logging configurations Although the underlying issue can I think be solved by having user code use slf4j 1.7.6+, it might be helpful and consistent to update Spark's slf4j too. I see no reason to believe it would be incompatible with other 1.7.x releases: http://www.slf4j.org/news.html Lots of different version of slf4j are in use in the wild and anecdotally I have never seen an issue mixing them. Author: Sean Owen Closes #4184 from srowen/SPARK-3782 and squashes the following commits: 5608d28 [Sean Owen] Update slf4j to 1.7.10 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index b993391b15042..05cb3797fc55b 100644 --- a/pom.xml +++ b/pom.xml @@ -117,7 +117,7 @@ 2.0.1 0.21.0 shaded-protobuf - 1.7.5 + 1.7.10 1.2.17 1.0.4 2.4.1 From 1c30afdf94b27e1ad65df0735575306e65d148a1 Mon Sep 17 00:00:00 2001 From: Jacek Lewandowski Date: Sun, 25 Jan 2015 15:15:09 -0800 Subject: [PATCH 21/41] SPARK-5382: Use SPARK_CONF_DIR in spark-class if it is defined Author: Jacek Lewandowski Closes #4179 from jacek-lewandowski/SPARK-5382-1.3 and squashes the following commits: 55d7791 [Jacek Lewandowski] SPARK-5382: Use SPARK_CONF_DIR in spark-class if it is defined --- bin/spark-class | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bin/spark-class b/bin/spark-class index 1b945461fabc8..2f0441bb3c1c2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" . "$FWDIR"/bin/load-spark-env.sh @@ -120,8 +121,8 @@ fi JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" +if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then + JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`" fi # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! From 9f6435763d173d2abf82d16b5878983fa8bf3419 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 15:25:05 -0800 Subject: [PATCH 22/41] SPARK-4506 [DOCS] Addendum: Update more docs to reflect that standalone works in cluster mode This is a trivial addendum to SPARK-4506, which was already resolved. noted by Asim Jalis in SPARK-4506. 
Author: Sean Owen Closes #4160 from srowen/SPARK-4506 and squashes the following commits: 5f5f7df [Sean Owen] Update more docs to reflect that standalone works in cluster mode --- docs/submitting-applications.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 3bd1deaccfafe..14a87f8436984 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for standalone -clusters, Mesos clusters, or Python applications. +the drivers and the executors. Note that `cluster` mode is currently not supported for +Mesos clusters or Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. From 8f5c827b01026bf45fc774ed7387f11a941abea8 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 25 Jan 2015 15:34:20 -0800 Subject: [PATCH 23/41] [SPARK-5344][WebUI] HistoryServer cannot recognize that inprogress file was renamed to completed file `FsHistoryProvider` tries to update application status but if `checkForLogs` is called before `.inprogress` file is renamed to completed file, the file is not recognized as completed. Author: Kousuke Saruta Closes #4132 from sarutak/SPARK-5344 and squashes the following commits: 9658008 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5344 d2c72b6 [Kousuke Saruta] Fixed update issue of FsHistoryProvider --- .../deploy/history/FsHistoryProvider.scala | 4 +++- .../history/FsHistoryProviderSuite.scala | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2b084a2d73b78..0ae45f4ad9130 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -203,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!logInfos.isEmpty) { val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() def addIfAbsent(info: FsApplicationHistoryInfo) = { - if (!newApps.contains(info.id)) { + if (!newApps.contains(info.id) || + newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) && + !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { newApps += (info.id -> info) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 8379883e065e7..3fbc1a21d10ed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -167,6 +167,29 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers list.size should be (1) } + test("history file is renamed from inprogress to completed") { + val conf = new SparkConf() + 
.set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set("spark.testing", "true") + val provider = new FsHistoryProvider(conf) + + val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L) + ) + provider.checkForLogs() + val appListBeforeRename = provider.getListing() + appListBeforeRename.size should be (1) + appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS) + + logFile1.renameTo(new File(testDir, "app1")) + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS) + } + private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec], events: SparkListenerEvent*) = { val out = From fc2168f04e9b2c7ce45f59db0bf632a26d56c72b Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Sun, 25 Jan 2015 16:48:26 -0800 Subject: [PATCH 24/41] [SPARK-5326] Show fetch wait time as optional metric in the UI With this change, here's what the UI looks like: ![image](https://cloud.githubusercontent.com/assets/1108612/5809994/1ec8a904-9ff4-11e4-8f24-6a59a1a858f7.png) If you want to locally test this, you need to spin up multiple executors, because the shuffle read metrics are only shown for data read remotely. Author: Kay Ousterhout Closes #4110 from kayousterhout/SPARK-5326 and squashes the following commits: 610051e [Kay Ousterhout] Josh style comments 5feaa28 [Kay Ousterhout] What is the difference here?? aa129cb [Kay Ousterhout] Removed inadvertent change 721c742 [Kay Ousterhout] Improved tooltip f3a7111 [Kay Ousterhout] Style fix 679b4e9 [Kay Ousterhout] [SPARK-5326] Show fetch wait time as optional metric in the UI --- .../org/apache/spark/ui/static/webui.css | 3 +- .../scala/org/apache/spark/ui/ToolTips.scala | 6 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 40 ++++++++++++++++++- .../spark/ui/jobs/TaskDetailsClassNames.scala | 1 + 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a1f7133f897ee..f23ba9dba167f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -190,6 +190,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. 
*/ -.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time, +.getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 6f446c5a95a0a..4307029d44fbb 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,8 +24,10 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = - """Time spent deserializating the task closure on the executor.""" + val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." + + val SHUFFLE_READ_BLOCKED_TIME = + "Time that the task spent blocked waiting for shuffle data to be read from remote machines." val INPUT = "Bytes read from Hadoop or from Spark storage." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 09a936c2234c0..d8be1b20b3acd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -132,6 +132,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Task Deserialization Time + {if (hasShuffleRead) { +
  • + + + Shuffle Read Blocked Time + +
  • + }}
  • @@ -167,7 +176,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input", "")) else Nil} ++ {if (hasOutput) Seq(("Output", "")) else Nil} ++ - {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ + {if (hasShuffleRead) { + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read", "")) + } else { + Nil + }} ++ {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) else Nil} ++ @@ -271,6 +285,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val outputQuantiles =
  • +: getFormattedSizeQuantiles(outputSizes) + val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + } + val shuffleReadBlockedQuantiles = +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } @@ -308,7 +328,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {gettingResultQuantiles}, if (hasInput) {inputQuantiles} else Nil, if (hasOutput) {outputQuantiles} else Nil, - if (hasShuffleRead) {shuffleReadQuantiles} else Nil, + if (hasShuffleRead) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadQuantiles} + } else { + Nil + }, if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) @@ -377,6 +404,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") + val maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime) + val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("") + val shuffleReadBlockedTimeReadable = + maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -449,6 +481,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { }} {if (hasShuffleRead) { + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 2d13bb6ddde42..37cf2c207ba40 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -27,6 +27,7 @@ package org.apache.spark.ui.jobs private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" + val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } From 0528b85cf96f9c9c074b5fbb5b9c5dd8071c0bc7 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 25 Jan 2015 19:16:44 -0800 Subject: [PATCH 25/41] SPARK-4430 [STREAMING] [TEST] Apache RAT Checks fail spuriously on test files Another trivial one. The RAT failure was due to temp files from `FailureSuite` not being cleaned up. This just makes the cleanup more reliable by using the standard temp dir mechanism. 
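In isolation, the cleanup pattern the fix adopts looks roughly like this (a sketch assuming Spark's internal org.apache.spark.util.Utils helpers; the test body itself is elided):

import java.io.File
import org.apache.spark.util.Utils

// Acquire a unique temporary directory instead of a fixed "FailureSuite" path.
val directory = Utils.createTempDir().getAbsolutePath

// ... exercise the streaming failure scenarios against `directory` ...

// Delete it eagerly in the suite's teardown so no stray files are left behind
// to trip the Apache RAT license check.
Utils.deleteRecursively(new File(directory))
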
Author: Sean Owen Closes #4189 from srowen/SPARK-4430 and squashes the following commits: 9ea63ff [Sean Owen] Properly acquire a temp directory to ensure it is cleaned up at shutdown, which helps avoid a RAT check failure --- .../scala/org/apache/spark/streaming/FailureSuite.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 40434b1f9b709..6500608bba87c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -28,21 +28,16 @@ import java.io.File */ class FailureSuite extends TestSuiteBase with Logging { - var directory = "FailureSuite" + val directory = Utils.createTempDir().getAbsolutePath val numBatches = 30 override def batchDuration = Milliseconds(1000) override def useManualClock = false - override def beforeFunction() { - super.beforeFunction() - Utils.deleteRecursively(new File(directory)) - } - override def afterFunction() { - super.afterFunction() Utils.deleteRecursively(new File(directory)) + super.afterFunction() } test("multiple failures with map") { From 8df9435512e4b4df12b55143a910d9ce3d4bd62c Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 25 Jan 2015 19:28:53 -0800 Subject: [PATCH 26/41] [SPARK-5268] don't stop CoarseGrainedExecutorBackend for irrelevant DisassociatedEvent https://issues.apache.org/jira/browse/SPARK-5268 In CoarseGrainedExecutorBackend, we subscribe DisassociatedEvent in executor backend actor and exit the program upon receive such event... let's consider the following case The user may develop an Akka-based program which starts the actor with Spark's actor system and communicate with an external actor system (e.g. an Akka-based receiver in spark streaming which communicates with an external system) If the external actor system fails or disassociates with the actor within spark's system with purpose, we may receive DisassociatedEvent and the executor is restarted. This is not the expected behavior..... ---- This is a simple fix to check the event before making the quit decision Author: CodingCat Closes #4063 from CodingCat/SPARK-5268 and squashes the following commits: 4d7d48e [CodingCat] simplify the log 18c36f4 [CodingCat] more descriptive log f299e0b [CodingCat] clean log 1632e79 [CodingCat] check whether DisassociatedEvent is relevant before quit --- .../spark/executor/CoarseGrainedExecutorBackend.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9a4adfbbb3d71..823825302658c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend( } case x: DisassociatedEvent => - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) + if (x.remoteAddress == driver.anchorPath.address) { + logError(s"Driver $x disassociated! 
Shutting down.") + System.exit(1) + } else { + logWarning(s"Received irrelevant DisassociatedEvent $x") + } case StopExecutor => logInfo("Driver commanded a shutdown") From 81251682edfb6a0de47e9dd72b2f63ee298fca33 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 25 Jan 2015 22:18:09 -0800 Subject: [PATCH 27/41] [SPARK-5384][mllib] Vectors.sqdist returns inconsistent results for sparse/dense vectors when the vectors have different lengths JIRA issue: https://issues.apache.org/jira/browse/SPARK-5384 Currently `Vectors.sqdist` return inconsistent result for sparse/dense vectors when the vectors have different lengths, please refer to JIRA for sample PR scope: Unify the sqdist logic for dense/sparse vectors and fix the inconsistency, also remove the possible sparse to dense conversion in the original code. For reviewers: Maybe we should first discuss what's the correct behavior. 1. Vectors for sqdist must have the same length, like in breeze? 2. If they can have different lengths, what's the correct result for sqdist? (should the extra part get into calculation?) I'll update PR with more optimization and additional ut afterwards. Thanks. Author: Yuhao Yang Closes #4183 from hhbyyh/fixDouble and squashes the following commits: 1f17328 [Yuhao Yang] limit PR scope to size constraints only 54cbf97 [Yuhao Yang] fix Vectors.sqdist inconsistence --- .../scala/org/apache/spark/mllib/linalg/Vectors.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7ee0224ad4662..b3022add38469 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -333,7 +333,7 @@ object Vectors { math.pow(sum, 1.0 / p) } } - + /** * Returns the squared distance between two Vectors. * @param v1 first Vector. @@ -341,8 +341,9 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { + require(v1.size == v2.size, "vector dimension mismatch") var squaredDistance = 0.0 - (v1, v2) match { + (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => val v1Values = v1.values val v1Indices = v1.indices @@ -350,12 +351,12 @@ object Vectors { val v2Indices = v2.indices val nnzv1 = v1Indices.size val nnzv2 = v2Indices.size - + var kv1 = 0 var kv2 = 0 while (kv1 < nnzv1 || kv2 < nnzv2) { var score = 0.0 - + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { score = v1Values(kv1) kv1 += 1 @@ -397,7 +398,7 @@ object Vectors { val nnzv1 = indices.size val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 - + while (kv2 < nnzv2) { var score = 0.0 if (kv2 != iv1) { From 142093179a4c40bdd90744191034de7b94a963ff Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 26 Jan 2015 12:51:32 -0800 Subject: [PATCH 28/41] [SPARK-5355] use j.u.c.ConcurrentHashMap instead of TrieMap j.u.c.ConcurrentHashMap is more battle tested. 
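In outline, the change swaps the TrieMap-backed settings for a plain java.util.concurrent.ConcurrentHashMap and routes reads through Option. A simplified sketch of the idea follows (the class name is illustrative; the real diff appears below):

import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

class ThreadSafeSettings {
  private val settings = new ConcurrentHashMap[String, String]()

  def set(key: String, value: String): Unit = settings.put(key, value)

  // Option(...) turns the null returned for a missing key into None.
  def getOption(key: String): Option[String] = Option(settings.get(key))

  def setIfMissing(key: String, value: String): Unit = settings.putIfAbsent(key, value)

  // Iteration over a ConcurrentHashMap is weakly consistent, so taking a
  // snapshot never throws even while other threads keep writing.
  def getAll: Array[(String, String)] =
    settings.entrySet().asScala.map(e => (e.getKey, e.getValue)).toArray
}
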
cc rxin JoshRosen pwendell Author: Davies Liu Closes #4208 from davies/safe-conf and squashes the following commits: c2182dc [Davies Liu] address comments, fix tests 3a1d821 [Davies Liu] fix test da14ced [Davies Liu] Merge branch 'master' of github.com:apache/spark into safe-conf ae4d305 [Davies Liu] change to j.u.c.ConcurrentMap f8fa1cf [Davies Liu] change to TrieMap a1d769a [Davies Liu] make SparkConf thread-safe --- .../scala/org/apache/spark/SparkConf.scala | 38 ++++++++++--------- .../deploy/worker/WorkerArgumentsTest.scala | 4 +- .../apache/spark/storage/LocalDirsSuite.scala | 2 +- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f9d4aa4240e9d..cd91c8f87547b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,9 +17,11 @@ package org.apache.spark +import java.util.concurrent.ConcurrentHashMap + import scala.collection.JavaConverters._ -import scala.collection.concurrent.TrieMap -import scala.collection.mutable.{HashMap, LinkedHashSet} +import scala.collection.mutable.LinkedHashSet + import org.apache.spark.serializer.KryoSerializer /** @@ -47,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new TrieMap[String, String]() + private val settings = new ConcurrentHashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) { - settings(k) = v + set(k, v) } } @@ -64,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } - settings(key) = value + settings.put(key, value) this } @@ -130,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]) = { - this.settings ++= settings + this.settings.putAll(settings.toMap.asJava) this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - if (!settings.contains(key)) { - settings(key) = value - } + settings.putIfAbsent(key, value) this } @@ -164,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Get a parameter; throws a NoSuchElementException if it's not set */ def get(key: String): String = { - settings.getOrElse(key, throw new NoSuchElementException(key)) + getOption(key).getOrElse(throw new NoSuchElementException(key)) } /** Get a parameter, falling back to a default if not set */ def get(key: String, defaultValue: String): String = { - settings.getOrElse(key, defaultValue) + getOption(key).getOrElse(defaultValue) } /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - settings.get(key) + Option(settings.get(key)) } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.toArray + def getAll: Array[(String, String)] = { + settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray + } /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { @@ -225,11 +227,11 @@ class SparkConf(loadDefaults: Boolean) 
extends Cloneable with Logging { def getAppId: String = get("spark.app.id") /** Does the configuration contain a given parameter? */ - def contains(key: String): Boolean = settings.contains(key) + def contains(key: String): Boolean = settings.containsKey(key) /** Copy this object */ override def clone: SparkConf = { - new SparkConf(false).setAll(settings) + new SparkConf(false).setAll(getAll) } /** @@ -241,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { - if (settings.contains("spark.local.dir")) { + if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." logWarning(msg) @@ -266,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate spark.executor.extraJavaOptions - settings.get(executorOptsKey).map { javaOpts => + getOption(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." @@ -346,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * configuration out for debugging. */ def toDebugString: String = { - settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") + getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 1a28a9a187cd7..372d7aa453008 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() @@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index dae7bf0e336de..8cf951adb354b 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -49,7 +49,7 @@ class LocalDirsSuite extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } // spark.local.dir only contains invalid directories, but that's not a problem since From c094c73270837602b875b18998133a2364a89e45 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 26 Jan 2015 13:07:49 -0800 Subject: [PATCH 29/41] [SPARK-5339][BUILD] build/mvn doesn't work because of invalid URL for maven's tgz. build/mvn will automatically download tarball of maven. But currently, the URL is invalid. 
Author: Kousuke Saruta Closes #4124 from sarutak/SPARK-5339 and squashes the following commits: 6e96121 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5339 0e012d1 [Kousuke Saruta] Updated Maven version to 3.2.5 ca26499 [Kousuke Saruta] Fixed URL of the tarball of Maven --- build/mvn | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/build/mvn b/build/mvn index 43471f83e904c..f91e2b4bdcc02 100755 --- a/build/mvn +++ b/build/mvn @@ -68,10 +68,10 @@ install_app() { # Install maven under the build/ folder install_mvn() { install_app \ - "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \ - "apache-maven-3.2.3-bin.tar.gz" \ - "apache-maven-3.2.3/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ + "apache-maven-3.2.5-bin.tar.gz" \ + "apache-maven-3.2.5/bin/mvn" + MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" } # Install zinc under the build/ folder From 54e7b456dd56c9e52132154e699abca87563465b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 26 Jan 2015 14:23:42 -0800 Subject: [PATCH 30/41] SPARK-4147 [CORE] Reduce log4j dependency Defer use of log4j class until it's known that log4j 1.2 is being used. This may avoid dealing with log4j dependencies for callers that reroute slf4j to another logging framework. The only change is to push one half of the check in the original `if` condition inside. This is a trivial change, may or may not actually solve a problem, but I think it's all that makes sense to do for SPARK-4147. Author: Sean Owen Closes #4190 from srowen/SPARK-4147 and squashes the following commits: 4e99942 [Sean Owen] Defer use of log4j class until it's known that log4j 1.2 is being used. This may avoid dealing with log4j dependencies for callers that reroute slf4j to another logging framework. 
--- .../main/scala/org/apache/spark/Logging.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index d4f2624061e35..419d093d55643 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -118,15 +118,17 @@ trait Logging { // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently // org.apache.logging.slf4j.Log4jLoggerFactory val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) - val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4j12Initialized && usingLog4j12) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (usingLog4j12) { + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized) { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } Logging.initialized = true From b38034e878546a12c6d52f17fc961fd1a2453b97 Mon Sep 17 00:00:00 2001 From: "David Y. Ross" Date: Mon, 26 Jan 2015 14:26:10 -0800 Subject: [PATCH 31/41] Fix command spaces issue in make-distribution.sh Storing command in variables is tricky in bash, use an array to handle all issues with spaces, quoting, etc. See: http://mywiki.wooledge.org/BashFAQ/050 Author: David Y. Ross Closes #4126 from dyross/dyr-fix-make-distribution and squashes the following commits: 4ce522b [David Y. Ross] Fix command spaces issue in make-distribution.sh --- make-distribution.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index 4e2f400be3053..0adca7851819b 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -115,7 +115,7 @@ if which git &>/dev/null; then unset GITREV fi -if ! which $MVN &>/dev/null; then +if ! which "$MVN" &>/dev/null; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; @@ -171,13 +171,16 @@ cd "$SPARK_HOME" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="$MVN clean package -DskipTests $@" +# Store the command as an array because $MVN variable might have spaces in it. +# Normal quoting tricks don't work. +# See: http://mywiki.wooledge.org/BashFAQ/050 +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." 
-echo -e "\$ $BUILD_COMMAND\n" +echo -e "\$ ${BUILD_COMMAND[@]}\n" -${BUILD_COMMAND} +"${BUILD_COMMAND[@]}" # Make directories rm -rf "$DISTDIR" From 0497ea51ac345f8057d222a18dbbf8eae78f5b92 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 26 Jan 2015 14:32:27 -0800 Subject: [PATCH 32/41] SPARK-960 [CORE] [TEST] JobCancellationSuite "two jobs sharing the same stage" is broken This reenables and fixes this test, after addressing two issues: - The Semaphore that was intended to be shared locally was being serialized and copied; it's now a static member in the companion object as in other tests - Later changes to Spark means that cancelling the first task will not cancel the shared stage and therefore the second task should succeed Author: Sean Owen Closes #4180 from srowen/SPARK-960 and squashes the following commits: 43da66f [Sean Owen] Fix 'two jobs sharing the same stage' test and reenable it: truly share a Semaphore locally as intended, and update expectation of failure in non-cancelled task --- .../org/apache/spark/JobCancellationSuite.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 7584ae79fc920..21487bc24d58a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } - ignore("two jobs sharing the same stage") { + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched - // sem2: make sure the first stage is not finished until cancel is issued + // twoJobsSharingStageSemaphore: + // make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) - val sem2 = new Semaphore(0) sc = new SparkContext("local[2]", "test") sc.addSparkListener(new SparkListener { @@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter // Create two actions that would share the some stages. val rdd = sc.parallelize(1 to 10, 2).map { i => - sem2.acquire() + JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) }.reduceByKey(_+_) val f1 = rdd.collectAsync() @@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter future { sem1.acquire() f1.cancel() - sem2.release(10) + JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) } - // Expect both to fail now. - // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + // Expect f1 to fail due to cancellation, intercept[SparkException] { f1.get() } - intercept[SparkException] { f2.get() } + // but f2 should not be affected + f2.get() } def testCount() { @@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter object JobCancellationSuite { val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) + val twoJobsSharingStageSemaphore = new Semaphore(0) } From 661e0fca5d5d86efab5fb26da600ac2ac96b09ec Mon Sep 17 00:00:00 2001 From: Elmer Garduno Date: Mon, 26 Jan 2015 17:40:48 -0800 Subject: [PATCH 33/41] [SPARK-5052] Add common/base classes to fix guava methods signatures. Fixes problems with incorrect method signatures related to shaded classes. For discussion see the jira issue. 
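The failure mode, roughly: Guava's Optional is kept at its original com.google.common.base package in the assembly, but several of its public methods take or return Function and Supplier. If those two classes are relocated while Optional is not, the relocated signatures no longer match call sites compiled against stock Guava, so this patch keeps Function and Supplier unshaded as well. A hedged sketch of such a call site (assuming Optional.transform(Function), present in recent Guava versions, is among the affected methods):

    import com.google.common.base.{Function, Optional}

    object GuavaSignatureExample {
      def main(args: Array[String]): Unit = {
        val answer: Optional[java.lang.Integer] = Optional.of(Int.box(41))
        // transform is declared as transform(Function<? super T, V>). If Function were
        // relocated inside the assembly while Optional stayed unshaded, this call,
        // compiled against stock Guava, would no longer link at runtime.
        val plusOne = new Function[java.lang.Integer, java.lang.Integer] {
          override def apply(x: java.lang.Integer): java.lang.Integer = Int.box(x + 1)
        }
        println(answer.transform(plusOne).get()) // 42
      }
    }
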
Author: Elmer Garduno Closes #3874 from elmer-garduno/fix_guava_signatures and squashes the following commits: aa5d8e0 [Elmer Garduno] Unshade common/base[Function|Supplier] classes to fix guava methods signatures. --- assembly/pom.xml | 2 ++ core/pom.xml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/assembly/pom.xml b/assembly/pom.xml index b2a9d0780ee2b..594fa0c779e1b 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -142,8 +142,10 @@ com/google/common/base/Absent* + com/google/common/base/Function com/google/common/base/Optional* com/google/common/base/Present* + com/google/common/base/Supplier diff --git a/core/pom.xml b/core/pom.xml index d9a49c9e08afc..1984682b9c099 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -372,8 +372,10 @@ com.google.guava:guava com/google/common/base/Absent* + com/google/common/base/Function com/google/common/base/Optional* com/google/common/base/Present* + com/google/common/base/Supplier From f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Mon, 26 Jan 2015 18:03:21 -0800 Subject: [PATCH 34/41] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train... ... decision tree model Labels loaded from libsvm files are mapped to 0.0 if they are negative labels because they should be nonnegative value. Author: lewuathe Closes #3975 from Lewuathe/map-negative-label-to-positive and squashes the following commits: 12d1d59 [lewuathe] [SPARK-5119] Fix code styles 6d9a18a [lewuathe] [SPARK-5119] Organize test codes 62a150c [lewuathe] [SPARK-5119] Modify Impurities throw exceptions with negatie labels 3336c21 [lewuathe] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train decision tree model --- .../spark/mllib/tree/impurity/Entropy.scala | 5 +++ .../spark/mllib/tree/impurity/Gini.scala | 5 +++ .../spark/mllib/tree/ImpuritySuite.scala | 42 +++++++++++++++++++ 3 files changed, 52 insertions(+) create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0e02345aa3774..b7950e00786ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int) throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc val lbl = label.toInt require(lbl < stats.length, s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "Entropy does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7c83cd48e16a0..c946db9c0d1c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int) throw new 
IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula val lbl = label.toInt require(lbl < stats.length, s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "GiniImpurity does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala new file mode 100644 index 0000000000000..92b498580af03 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + */ +class ImpuritySuite extends FunSuite with MLlibTestSparkContext { + test("Gini impurity does not support negative labels") { + val gini = new GiniAggregator(2) + intercept[IllegalArgumentException] { + gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } + + test("Entropy does not support negative labels") { + val entropy = new EntropyAggregator(2) + intercept[IllegalArgumentException] { + entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } +} From d6894b1c5314c751cfdaf78005b99b2104e6e4d1 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 26 Jan 2015 19:46:17 -0800 Subject: [PATCH 35/41] [SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomForests I've added support for sampling_rate not equal to 1.0 . I have two major questions. 1. A Scala style test is failing, since the number of parameters now exceed 10. 2. I would like suggestions to understand how to test this. 
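As a usage sketch, with this change the per-tree row subsampling rate can be requested through Strategy, mirroring the new test below (assumes a SparkContext sc and a LabeledPoint dataset; the file path is only an example):

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini
    import org.apache.spark.mllib.util.MLUtils

    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

    val strategy = new Strategy(algo = Algo.Classification, impurity = Gini, maxDepth = 4,
      numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int])
    // Each tree is now trained on a bootstrap sample of ~70% of the rows instead of 100%.
    strategy.subsamplingRate = 0.7

    val model = RandomForest.trainClassifier(data, strategy, numTrees = 10,
      featureSubsetStrategy = "auto", seed = 42)

Previously a forest with numTrees > 1 silently forced the rate back to 1.0; with this patch the bootstrap sample for each tree honors strategy.subsamplingRate (still with replacement when numTrees > 1).
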
Author: MechCoder Closes #4073 from MechCoder/spark-3726 and squashes the following commits: 8012fb2 [MechCoder] Add test in Strategy e0e0d9c [MechCoder] TST: Add better test d1df1b2 [MechCoder] Add test to verify subsampling behavior a7bfc70 [MechCoder] [SPARK-3726] Allow sampling_rate not equal to 1.0 --- .../apache/spark/mllib/tree/RandomForest.scala | 16 +++++----------- .../mllib/tree/configuration/Strategy.scala | 3 +++ .../spark/mllib/tree/RandomForestSuite.scala | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index e9304b5e5c650..482dd4b272d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -140,6 +140,7 @@ private class RandomForest ( logDebug("maxBins = " + metadata.maxBins) logDebug("featureSubsetStrategy = " + featureSubsetStrategy) logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. @@ -155,19 +156,12 @@ private class RandomForest ( // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val (subsample, withReplacement) = { - // TODO: Have a stricter check for RF in the strategy - val isRandomForest = numTrees > 1 - if (isRandomForest) { - (1.0, true) - } else { - (strategy.subsamplingRate, false) - } - } + val withReplacement = if (numTrees > 1) true else false val baggedInput - = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) - .persist(StorageLevel.MEMORY_AND_DISK) + = BaggedPoint.convertToBaggedRDD(treeInput, + strategy.subsamplingRate, numTrees, + withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 972959885f396..3308adb6752ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -156,6 +156,9 @@ class Strategy ( s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") + require(subsamplingRate > 0 && subsamplingRate <= 1, + s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + + s"$subsamplingRate") } /** Returns a shallow copy of this instance. 
*/ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index f7f0f20c6c125..55e963977b54f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) EnsembleTestHelper.validateClassifier(model, arr, 1.0) } + + test("subsampling rate in RandomForest"){ + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], + useNodeIdCache = true) + + val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + strategy.subsamplingRate = 0.5 + val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + assert(rf1.toDebugString != rf2.toDebugString) + } + } From 7b0ed797958a91cda73baa7aa49ce66bfcb6b64b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Jan 2015 01:29:14 -0800 Subject: [PATCH 36/41] [SPARK-5419][Mllib] Fix the logic in Vectors.sqdist The current implementation in Vectors.sqdist is not efficient because of allocating temp arrays. There is also a bug in the code `v1.indices.length / v1.size < 0.5`. This pr fixes the bug and refactors sqdist without allocating new arrays. Author: Liang-Chi Hsieh Closes #4217 from viirya/fix_sqdist and squashes the following commits: e8b0b3d [Liang-Chi Hsieh] For review comments. 314c424 [Liang-Chi Hsieh] Fix sqdist bug. --- .../apache/spark/mllib/linalg/Vectors.scala | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index b3022add38469..2834ea75ceb8f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -371,18 +371,23 @@ object Vectors { squaredDistance += score * score } - case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => + case (v1: SparseVector, v2: DenseVector) => squaredDistance = sqdist(v1, v2) - case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 => + case (v1: DenseVector, v2: SparseVector) => squaredDistance = sqdist(v2, v1) - // When a SparseVector is approximately dense, we treat it as a DenseVector - case (v1, v2) => - squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) => - val score = elems._1 - elems._2 - distance + score * score + case (DenseVector(vv1), DenseVector(vv2)) => + var kv = 0 + val sz = vv1.size + while (kv < sz) { + val score = vv1(kv) - vv2(kv) + squaredDistance += score * score + kv += 1 } + case _ => + throw new IllegalArgumentException("Do not support vector type " + v1.getClass + + " and " + v2.getClass) } squaredDistance } From 914267484a7156718ab6da37a6a42bbb074b51ac Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 27 Jan 2015 01:46:17 -0800 Subject: [PATCH 37/41] [SPARK-5321] Support for transposing local matrices Support for transposing local matrices added. 
The `.transpose` function creates a new object re-using the backing array(s) but switches `numRows` and `numCols`. Operations check the flag `.isTransposed` to see whether the indexing in `values` should be modified. This PR will pave the way for transposing `BlockMatrix`. Author: Burak Yavuz Closes #4109 from brkyvz/SPARK-5321 and squashes the following commits: 87ab83c [Burak Yavuz] fixed scalastyle caf4438 [Burak Yavuz] addressed code review v3 c524770 [Burak Yavuz] address code review comments 2 77481e8 [Burak Yavuz] fixed MiMa f1c1742 [Burak Yavuz] small refactoring ccccdec [Burak Yavuz] fixed failed test dd45c88 [Burak Yavuz] addressed code review a01bd5f [Burak Yavuz] [SPARK-5321] Fixed MiMa issues 2a63593 [Burak Yavuz] [SPARK-5321] fixed bug causing failed gemm test c55f29a [Burak Yavuz] [SPARK-5321] Support for transposing local matrices cleaned up c408c05 [Burak Yavuz] [SPARK-5321] Support for transposing local matrices added --- .../org/apache/spark/mllib/linalg/BLAS.scala | 170 +++------ .../apache/spark/mllib/linalg/Matrices.scala | 347 +++++++++++------- .../apache/spark/mllib/linalg/BLASSuite.scala | 72 ++-- .../linalg/BreezeMatrixConversionSuite.scala | 9 + .../spark/mllib/linalg/MatricesSuite.scala | 106 +++++- project/MimaExcludes.scala | 14 + 6 files changed, 436 insertions(+), 282 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 3414daccd7ca4..34e0392f1b21a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -257,80 +257,58 @@ private[spark] object BLAS extends Serializable with Logging { /** * C := alpha * A * B + beta * C - * @param transA whether to use the transpose of matrix A (true), or A itself (false). - * @param transB whether to use the transpose of matrix B (true), or B itself (false). * @param alpha a scalar to scale the multiplication A * B. * @param A the matrix A that will be left multiplied to B. Size of m x k. * @param B the matrix B that will be left multiplied by A. Size of k x n. * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. + * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false. */ def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: Matrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { + require(!C.isTransposed, + "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0) { logDebug("gemm: alpha is equal to 0. Returning C.") } else { A match { case sparse: SparseMatrix => - gemm(transA, transB, alpha, sparse, B, beta, C) + gemm(alpha, sparse, B, beta, C) case dense: DenseMatrix => - gemm(transA, transB, alpha, dense, B, beta, C) + gemm(alpha, dense, B, beta, C) case _ => throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") } } } - /** - * C := alpha * A * B + beta * C - * - * @param alpha a scalar to scale the multiplication A * B. - * @param A the matrix A that will be left multiplied to B. Size of m x k. - * @param B the matrix B that will be left multiplied by A. Size of k x n. - * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. 
- */ - def gemm( - alpha: Double, - A: Matrix, - B: DenseMatrix, - beta: Double, - C: DenseMatrix): Unit = { - gemm(false, false, alpha, A, B, beta, C) - } - /** * C := alpha * A * B + beta * C * For `DenseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: DenseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols - val tAstr = if (!transA) "N" else "T" - val tBstr = if (!transB) "N" else "T" - - require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") - require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") - require(nB == C.numCols, - s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") - - nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, - beta, C.values, C.numRows) + val tAstr = if (A.isTransposed) "T" else "N" + val tBstr = if (B.isTransposed) "T" else "N" + val lda = if (!A.isTransposed) A.numRows else A.numCols + val ldb = if (!B.isTransposed) B.numRows else B.numCols + + require(A.numCols == B.numRows, + s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}") + require(A.numRows == C.numRows, + s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}") + require(B.numCols == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}") + + nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda, + B.values, ldb, beta, C.values, C.numRows) } /** @@ -338,17 +316,15 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: SparseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols + val mA: Int = A.numRows + val nB: Int = B.numCols + val kA: Int = A.numCols + val kB: Int = B.numRows require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") @@ -358,23 +334,23 @@ private[spark] object BLAS extends Serializable with Logging { val Avals = A.values val Bvals = B.values val Cvals = C.values - val Arows = if (!transA) A.rowIndices else A.colPtrs - val Acols = if (!transA) A.colPtrs else A.rowIndices + val ArowIndices = A.rowIndices + val AcolPtrs = A.colPtrs // Slicing is easy in this case. 
This is the optimal multiplication setting for sparse matrices - if (transA){ + if (A.isTransposed){ var colCounterForB = 0 - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var rowCounterForA = 0 val Cstart = colCounterForB * mA val Bstart = colCounterForB * kA while (rowCounterForA < mA) { - var i = Arows(rowCounterForA) - val indEnd = Arows(rowCounterForA + 1) + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * Bvals(Bstart + Acols(i)) + sum += Avals(i) * Bvals(Bstart + ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -385,19 +361,19 @@ private[spark] object BLAS extends Serializable with Logging { } } else { while (colCounterForB < nB) { - var rowCounter = 0 + var rowCounterForA = 0 val Cstart = colCounterForB * mA - while (rowCounter < mA) { - var i = Arows(rowCounter) - val indEnd = Arows(rowCounter + 1) + while (rowCounterForA < mA) { + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(colCounterForB, Acols(i)) + sum += Avals(i) * B(ArowIndices(i), colCounterForB) i += 1 } - val Cindex = Cstart + rowCounter + val Cindex = Cstart + rowCounterForA Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha - rowCounter += 1 + rowCounterForA += 1 } colCounterForB += 1 } @@ -410,17 +386,17 @@ private[spark] object BLAS extends Serializable with Logging { // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of // B, and added to C. var colCounterForB = 0 // the column to be updated in C - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var colCounterForA = 0 // The column of A to multiply with the row of B val Bstart = colCounterForB * kB val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) val Bval = Bvals(Bstart + colCounterForA) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -432,11 +408,11 @@ private[spark] object BLAS extends Serializable with Logging { var colCounterForA = 0 // The column of A to multiply with the row of B val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) - val Bval = B(colCounterForB, colCounterForA) * alpha + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) + val Bval = B(colCounterForA, colCounterForB) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -449,7 +425,6 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * @param trans whether to use the transpose of matrix A (true), or A itself (false). * @param alpha a scalar to scale the multiplication A * x. * @param A the matrix A that will be left multiplied to x. Size of m x n. * @param x the vector x that will be left multiplied by A. Size of n x 1. 
@@ -457,65 +432,43 @@ private[spark] object BLAS extends Serializable with Logging { * @param y the resulting vector y. Size of m x 1. */ def gemv( - trans: Boolean, alpha: Double, A: Matrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - - val mA: Int = if (!trans) A.numRows else A.numCols - val nx: Int = x.size - val nA: Int = if (!trans) A.numCols else A.numRows - - require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") - require(mA == y.size, - s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + require(A.numCols == x.size, + s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") + require(A.numRows == y.size, + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { A match { case sparse: SparseMatrix => - gemv(trans, alpha, sparse, x, beta, y) + gemv(alpha, sparse, x, beta, y) case dense: DenseMatrix => - gemv(trans, alpha, dense, x, beta, y) + gemv(alpha, dense, x, beta, y) case _ => throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") } } } - /** - * y := alpha * A * x + beta * y - * - * @param alpha a scalar to scale the multiplication A * x. - * @param A the matrix A that will be left multiplied to x. Size of m x n. - * @param x the vector x that will be left multiplied by A. Size of n x 1. - * @param beta a scalar that can be used to scale vector y. - * @param y the resulting vector y. Size of m x 1. - */ - def gemv( - alpha: Double, - A: Matrix, - x: DenseVector, - beta: Double, - y: DenseVector): Unit = { - gemv(false, alpha, A, x, beta, y) - } - /** * y := alpha * A * x + beta * y * For `DenseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val tStrA = if (!trans) "N" else "T" - nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + val tStrA = if (A.isTransposed) "T" else "N" + val mA = if (!A.isTransposed) A.numRows else A.numCols + val nA = if (!A.isTransposed) A.numCols else A.numRows + nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } @@ -524,24 +477,21 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val xValues = x.values val yValues = y.values - - val mA: Int = if (!trans) A.numRows else A.numCols - val nA: Int = if (!trans) A.numCols else A.numRows + val mA: Int = A.numRows + val nA: Int = A.numCols val Avals = A.values - val Arows = if (!trans) A.rowIndices else A.colPtrs - val Acols = if (!trans) A.colPtrs else A.rowIndices + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices // Slicing is easy in this case. 
This is the optimal multiplication setting for sparse matrices - if (trans) { + if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { var i = Arows(rowCounter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 5a7281ec6dc3c..ad7e86827b368 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -34,8 +34,17 @@ sealed trait Matrix extends Serializable { /** Number of columns. */ def numCols: Int + /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + val isTransposed: Boolean = false + /** Converts to a dense array in column major. */ - def toArray: Array[Double] + def toArray: Array[Double] = { + val newArray = new Array[Double](numRows * numCols) + foreachActive { (i, j, v) => + newArray(j * numRows + i) = v + } + newArray + } /** Converts to a breeze matrix. */ private[mllib] def toBreeze: BM[Double] @@ -52,10 +61,13 @@ sealed trait Matrix extends Serializable { /** Get a deep copy of the matrix. */ def copy: Matrix + /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + def transpose: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ def multiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) + BLAS.gemm(1.0, this, y, 0.0, C) C } @@ -66,20 +78,6 @@ sealed trait Matrix extends Serializable { output } - /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(true, false, 1.0, this, y, 0.0, C) - C - } - - /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { - val output = new DenseVector(new Array[Double](numCols)) - BLAS.gemv(true, 1.0, this, y, 0.0, output) - output - } - /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() @@ -92,6 +90,16 @@ sealed trait Matrix extends Serializable { * backing array. For example, an operation such as addition or subtraction will only be * performed on the non-zero values in a `SparseMatrix`. */ private[mllib] def update(f: Double => Double): Matrix + + /** + * Applies a function `f` to all the active elements of dense and sparse matrix. The ordering + * of the elements are not defined. + * + * @param f the function takes three parameters where the first two parameters are the row + * and column indices respectively with the type `Int`, and the final parameter is the + * corresponding value in the matrix with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Int, Double) => Unit) } /** @@ -108,13 +116,35 @@ sealed trait Matrix extends Serializable { * @param numRows number of rows * @param numCols number of columns * @param values matrix entries in column major + * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in + * row major. 
*/ -class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { +class DenseMatrix( + val numRows: Int, + val numCols: Int, + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") - override def toArray: Array[Double] = values + /** + * Column-major dense matrix. + * The entry values are stored in a single array of doubles with columns listed in sequence. + * For example, the following matrix + * {{{ + * 1.0 2.0 + * 3.0 4.0 + * 5.0 6.0 + * }}} + * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ + def this(numRows: Int, numCols: Int, values: Array[Double]) = + this(numRows, numCols, values, false) override def equals(o: Any) = o match { case m: DenseMatrix => @@ -122,13 +152,22 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) case _ => false } - private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BDM[Double](numRows, numCols, values) + } else { + val breezeMatrix = new BDM[Double](numCols, numRows, values) + breezeMatrix.t + } + } private[mllib] def apply(i: Int): Double = values(i) private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) - private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + private[mllib] def index(i: Int, j: Int): Int = { + if (!isTransposed) i + numRows * j else j + numCols * i + } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { values(index(i, j)) = v @@ -148,7 +187,38 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) this } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. */ + override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + // outer loop over columns + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + f(i, j, values(indStart + i)) + i += 1 + } + j += 1 + } + } else { + // outer loop over rows + var i = 0 + while (i < numRows) { + var j = 0 + val indStart = i * numCols + while (j < numCols) { + f(i, j, values(indStart + j)) + j += 1 + } + i += 1 + } + } + } + + /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. */ def toSparse(): SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -157,9 +227,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) var j = 0 while (j < numCols) { var i = 0 - val indStart = j * numRows while (i < numRows) { - val v = values(indStart + i) + val v = values(index(i, j)) if (v != 0.0) { rowIndices += i spVals += v @@ -271,49 +340,73 @@ object DenseMatrix { * @param rowIndices the row index of the entry. They must be in strictly increasing order for each * column * @param values non-zero matrix entries in column major + * @param isTransposed whether the matrix is transposed. 
If true, the matrix can be considered + * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, + * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ class SparseMatrix( val numRows: Int, val numCols: Int, val colPtrs: Array[Int], val rowIndices: Array[Int], - val values: Array[Double]) extends Matrix { + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + - s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + - s"numCols: $numCols") + // The Or statement is for the case when the matrix is transposed + require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + + "column indices should be the number of columns + 1. Currently, colPointers.length: " + + s"${colPtrs.length}, numCols: $numCols") require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") - override def toArray: Array[Double] = { - val arr = new Array[Double](numRows * numCols) - var j = 0 - while (j < numCols) { - var i = colPtrs(j) - val indEnd = colPtrs(j + 1) - val offset = j * numRows - while (i < indEnd) { - val rowIndex = rowIndices(i) - arr(offset + rowIndex) = values(i) - i += 1 - } - j += 1 - } - arr + /** + * Column-major sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. They must be in strictly increasing + * order for each column + * @param values non-zero matrix entries in column major + */ + def this( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + } else { + val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices) + breezeMatrix.t + } } - private[mllib] def toBreeze: BM[Double] = - new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) - private[mllib] def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } private[mllib] def index(i: Int, j: Int): Int = { - Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + if (!isTransposed) { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } else { + Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j) + } } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { @@ -322,7 +415,7 @@ class SparseMatrix( throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. 
Only non-zero elements in Sparse Matrices can be updated.") } else { - values(index(i, j)) = v + values(ind) = v } } @@ -341,7 +434,38 @@ class SparseMatrix( this } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. */ + override def transpose: Matrix = + new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + var j = 0 + while (j < numCols) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + while (idx < idxEnd) { + f(rowIndices(idx), j, values(idx)) + idx += 1 + } + j += 1 + } + } else { + var i = 0 + while (i < numRows) { + var idx = colPtrs(i) + val idxEnd = colPtrs(i + 1) + while (idx < idxEnd) { + val j = rowIndices(idx) + f(i, j, values(idx)) + idx += 1 + } + i += 1 + } + } + } + + /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. */ def toDense(): DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } @@ -557,10 +681,9 @@ object Matrices { private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = { breeze match { case dm: BDM[Double] => - require(dm.majorStride == dm.rows, - "Do not support stride size different from the number of rows.") - new DenseMatrix(dm.rows, dm.cols, dm.data) + new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // There is no isTranspose flag for sparse matrices in Breeze new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( @@ -679,46 +802,28 @@ object Matrices { new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - val nCols = spMat.numCols - while (j < nCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i, j + startCol, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nCols = mat.numCols + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i, j + startCol, v) + cnt += 1 } - j += 1 - } - startCol += nCols - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i, j + startCol, v)) } - i += 1 } - j += 1 - } - startCol += nCols - data + startCol += nCols + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } @@ -744,14 +849,12 @@ object Matrices { require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + "don't match!") mat match { - case sparse: SparseMatrix => - hasSparse = true - case dense: DenseMatrix => + case sparse: SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose case _ => throw new IllegalArgumentException("Unsupported matrix 
format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") } numRows += mat.numRows - } if (!hasSparse) { val allValues = new Array[Double](numRows * numCols) @@ -759,61 +862,37 @@ object Matrices { matrices.foreach { mat => var j = 0 val nRows = mat.numRows - val values = mat.toArray - while (j < numCols) { - var i = 0 + mat.foreachActive { (i, j, v) => val indStart = j * numRows + startRow - val subMatStart = j * nRows - while (i < nRows) { - allValues(indStart + i) = values(subMatStart + i) - i += 1 - } - j += 1 + allValues(indStart + i) = v } startRow += nRows } new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - while (j < numCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i + startRow, j, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nRows = mat.numRows + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i + startRow, j, v) + cnt += 1 } - j += 1 - } - startRow += spMat.numRows - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startRow += nRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i + startRow, j, v)) } - i += 1 } - j += 1 - } - startRow += nRows - data + startRow += nRows + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 771878e925ea7..b0b78acd6df16 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -169,16 +169,17 @@ class BLASSuite extends FunSuite { } test("gemm") { - val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0)) + val BT = B.transpose - assert(dA multiply B ~== expected absTol 1e-15) - assert(sA multiply B ~== expected absTol 1e-15) + assert(dA.multiply(B) ~== expected absTol 1e-15) + assert(sA.multiply(B) ~== expected absTol 1e-15) val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) val C2 = C1.copy @@ -188,6 +189,10 @@ class BLASSuite extends FunSuite { val C6 = C1.copy val C7 = C1.copy val C8 = C1.copy + val C9 = C1.copy + val C10 = C1.copy + val C11 = C1.copy + val C12 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 
8.0, 0.0, 6.0, 6.0)) @@ -202,26 +207,40 @@ class BLASSuite extends FunSuite { withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemm(true, false, 1.0, dA, B, 2.0, C1) + gemm(1.0, dA.transpose, B, 2.0, C1) } } - val dAT = + val dATman = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) - val sAT = + val sATman = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply B ~== expected absTol 1e-15) - assert(sAT transposeMultiply B ~== expected absTol 1e-15) - - gemm(true, false, 1.0, dAT, B, 2.0, C5) - gemm(true, false, 1.0, sAT, B, 2.0, C6) - gemm(true, false, 2.0, dAT, B, 2.0, C7) - gemm(true, false, 2.0, sAT, B, 2.0, C8) + val dATT = dATman.transpose + val sATT = sATman.transpose + val BTT = BTman.transpose.asInstanceOf[DenseMatrix] + + assert(dATT.multiply(B) ~== expected absTol 1e-15) + assert(sATT.multiply(B) ~== expected absTol 1e-15) + assert(dATT.multiply(BTT) ~== expected absTol 1e-15) + assert(sATT.multiply(BTT) ~== expected absTol 1e-15) + + gemm(1.0, dATT, BTT, 2.0, C5) + gemm(1.0, sATT, BTT, 2.0, C6) + gemm(2.0, dATT, BTT, 2.0, C7) + gemm(2.0, sATT, BTT, 2.0, C8) + gemm(1.0, dA, BTT, 2.0, C9) + gemm(1.0, sA, BTT, 2.0, C10) + gemm(2.0, dA, BTT, 2.0, C11) + gemm(2.0, sA, BTT, 2.0, C12) assert(C5 ~== expected2 absTol 1e-15) assert(C6 ~== expected2 absTol 1e-15) assert(C7 ~== expected3 absTol 1e-15) assert(C8 ~== expected3 absTol 1e-15) + assert(C9 ~== expected2 absTol 1e-15) + assert(C10 ~== expected2 absTol 1e-15) + assert(C11 ~== expected3 absTol 1e-15) + assert(C12 ~== expected3 absTol 1e-15) } test("gemv") { @@ -233,17 +252,13 @@ class BLASSuite extends FunSuite { val x = new DenseVector(Array(1.0, 2.0, 3.0)) val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA multiply x ~== expected absTol 1e-15) - assert(sA multiply x ~== expected absTol 1e-15) + assert(dA.multiply(x) ~== expected absTol 1e-15) + assert(sA.multiply(x) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy - val y5 = y1.copy - val y6 = y1.copy - val y7 = y1.copy - val y8 = y1.copy val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -257,25 +272,18 @@ class BLASSuite extends FunSuite { assert(y4 ~== expected3 absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(true, 1.0, dA, x, 2.0, y1) + gemv(1.0, dA.transpose, x, 2.0, y1) } } - val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply x ~== expected absTol 1e-15) - assert(sAT transposeMultiply x ~== expected absTol 1e-15) - - gemv(true, 1.0, dAT, x, 2.0, y5) - gemv(true, 1.0, sAT, x, 2.0, y6) - gemv(true, 2.0, dAT, x, 2.0, y7) - gemv(true, 2.0, sAT, x, 2.0, y8) - assert(y5 ~== expected2 absTol 1e-15) - assert(y6 ~== expected2 absTol 1e-15) - assert(y7 ~== expected3 absTol 1e-15) - assert(y8 ~== expected3 absTol 1e-15) + val dATT = dAT.transpose + val sATT = sAT.transpose + + assert(dATT.multiply(x) ~== expected absTol 1e-15) + assert(sATT.multiply(x) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 73a6d3a27d868..2031032373971 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + // transposed matrix + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") { @@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(!matTransposed.values.eq(breeze.data), "has to copy data") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a35d0fe389fdd..b1ebfde0e5e57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -22,6 +22,9 @@ import java.util.Random import org.mockito.Mockito.when import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.mllib.util.TestingUtils._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite { assert(mat.numRows === m) assert(mat.numCols === n) assert(mat.values.eq(values), "should not copy data") - assert(mat.toArray.eq(values), "toArray should not copy data") } test("dense matrix construction with wrong dimension") { @@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite { assert(deMat1.toArray === deMat2.toArray) } + test("transpose") { + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val dAT = dA.transpose.asInstanceOf[DenseMatrix] + val sAT = sA.transpose.asInstanceOf[SparseMatrix] + val dATexpected = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sATexpected = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT.toBreeze === dATexpected.toBreeze) + assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dA(1, 0) === dAT(0, 1)) + assert(dA(2, 1) === dAT(1, 2)) + assert(sA(1, 0) === sAT(0, 1)) + assert(sA(2, 1) === sAT(1, 2)) + + assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") + assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") + + assert(dAT.toSparse().toBreeze === sATexpected.toBreeze) + assert(sAT.toDense().toBreeze === dATexpected.toBreeze) + } + + test("foreachActive") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + 
val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val dn = new DenseMatrix(m, n, allValues) + + val dnMap = MutableMap[(Int, Int), Double]() + dn.foreachActive { (i, j, value) => + dnMap.put((i, j), value) + } + assert(dnMap.size === 6) + assert(dnMap(0, 0) === 1.0) + assert(dnMap(1, 0) === 2.0) + assert(dnMap(2, 0) === 0.0) + assert(dnMap(0, 1) === 0.0) + assert(dnMap(1, 1) === 4.0) + assert(dnMap(2, 1) === 5.0) + + val spMap = MutableMap[(Int, Int), Double]() + sp.foreachActive { (i, j, value) => + spMap.put((i, j), value) + } + assert(spMap.size === 4) + assert(spMap(0, 0) === 1.0) + assert(spMap(1, 0) === 2.0) + assert(spMap(1, 1) === 4.0) + assert(spMap(2, 1) === 5.0) + } + test("horzcat, vertcat, eye, speye") { val m = 3 val n = 2 @@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite { val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(0, 1, 1, 2) + // transposed versions + val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0) + val colPtrsT = Array(0, 1, 3, 4) + val rowIndicesT = Array(0, 0, 1, 1) val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) + val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values) + val deMat1T = new DenseMatrix(n, m, allValuesT) + + // should equal spMat1 & deMat1 respectively + val spMat1TT = spMat1T.transpose + val deMat1TT = deMat1T.transpose + val deMat2 = Matrices.eye(3) val spMat2 = Matrices.speye(3) val deMat3 = Matrices.eye(2) @@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = Matrices.horzcat(Array[Matrix]()) assert(deHorz1.numRows === 3) @@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite { assert(deHorz2.numCols === 0) assert(deHorz2.toArray.length === 0) - assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) - assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spHorz2 ~== spHorz3 absTol 1e-15) assert(spHorz(0, 0) === 1.0) assert(spHorz(2, 1) === 5.0) assert(spHorz(0, 2) === 1.0) @@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite { assert(deHorz1(2, 4) === 1.0) assert(deHorz1(1, 4) === 0.0) + // containing transposed matrices + val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2)) + val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2)) + val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2)) + val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2)) + + assert(deHorz1T ~== deHorz1 absTol 1e-15) + assert(spHorzT ~== spHorz absTol 1e-15) + assert(spHorz2T ~== spHorz2 absTol 1e-15) + assert(spHorz3T ~== spHorz3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.horzcat(Array(spMat1, spMat3)) } @@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite { assert(deVert2.numCols === 0) assert(deVert2.toArray.length === 0) - assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix) - assert(spVert2.toBreeze === spVert3.toBreeze) + assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spVert2 ~== spVert3 absTol 1e-15) assert(spVert(0, 0) === 1.0) assert(spVert(2, 1) === 5.0) assert(spVert(3, 0) === 1.0) @@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite { assert(deVert1(3, 1) === 
0.0) assert(deVert1(4, 1) === 1.0) + // containing transposed matrices + val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3)) + val deVert1T = Matrices.vertcat(Array(deMat1TT, deMat3)) + val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3)) + val spVert3T = Matrices.vertcat(Array(deMat1TT, spMat3)) + + assert(deVert1T ~== deVert1 absTol 1e-15) + assert(spVertT ~== spVert absTol 1e-15) + assert(spVert2T ~== spVert2 absTol 1e-15) + assert(spVert3T ~== spVert3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.vertcat(Array(spMat1, spMat2)) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bc5d81f12d746..af0b0ebb9a383 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -52,6 +52,20 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrices.randn"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") ) ++ Seq( // SPARK-3325 ProblemFilters.exclude[MissingMethodProblem]( From ff356e2a21e31998cda3062e560a276a3bfaa7ab Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 27 Jan 2015 10:22:50 -0800 Subject: [PATCH 38/41] SPARK-5308 [BUILD] MD5 / SHA1 hash format doesn't match standard Maven output Here's one way to make the hashes match what Maven's plugins would create. It takes a little extra footwork since OS X doesn't have the same command line tools. An alternative is just to make Maven output these of course - would that be better? I ask in case there is a reason I'm missing, like, we need to hash files that Maven doesn't build. Author: Sean Owen Closes #4161 from srowen/SPARK-5308 and squashes the following commits: 70d09d0 [Sean Owen] Use $(...) syntax e25eff8 [Sean Owen] Generate MD5, SHA1 hashes in a format like Maven's plugin --- dev/create-release/create-release.sh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b1b8cb44e098b..b2a7e092a0291 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -122,8 +122,14 @@ if [[ ! "$@" =~ --package-only ]]; then for file in $(find . 
-type f) do echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - gpg --print-md MD5 $file > $file.md5; - gpg --print-md SHA1 $file > $file.sha1 + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 done nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id From fdaad4eb0388cfe43b5b6600927eb7b9182646f9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 27 Jan 2015 15:33:01 -0800 Subject: [PATCH 39/41] [MLlib] fix python example of ALS in guide fix python example of ALS in guide, use Rating instead of np.array. Author: Davies Liu Closes #4226 from davies/fix_als_guide and squashes the following commits: 1433d76 [Davies Liu] fix python example of als in guide --- docs/mllib-collaborative-filtering.md | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 2094963392295..ef18cec9371d6 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -192,12 +192,11 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva recommendation by measuring the Mean Squared Error of rating prediction. {% highlight python %} -from pyspark.mllib.recommendation import ALS -from numpy import array +from pyspark.mllib.recommendation import ALS, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda line: array([float(x) for x in line.split(',')])) +ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) # Build the recommendation model using Alternating Least Squares rank = 10 @@ -205,10 +204,10 @@ numIterations = 20 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data -testdata = ratings.map(lambda p: (int(p[0]), int(p[1]))) +testdata = ratings.map(lambda p: (p[0], p[1])) predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count() +MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() print("Mean Squared Error = " + str(MSE)) {% endhighlight %} @@ -217,7 +216,7 @@ signals), you can use the trainImplicit method to get better results. {% highlight python %} # Build the recommendation model using Alternating Least Squares based on implicit ratings -model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) +model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01) {% endhighlight %} From b1b35ca2e440df40b253bf967bb93705d355c1c0 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 27 Jan 2015 15:42:55 -0800 Subject: [PATCH 40/41] SPARK-5199. FS read metrics should support CombineFileSplits and track bytes from all FSs ...mbineFileSplits Author: Sandy Ryza Closes #4050 from sryza/sandy-spark-5199 and squashes the following commits: 864514b [Sandy Ryza] Add tests and fix bug 0d504f1 [Sandy Ryza] Prettify 915c7e6 [Sandy Ryza] Get metrics from all filesystems cdbc3e8 [Sandy Ryza] SPARK-5199. 
Input metrics should show up for InputFormats that return CombineFileSplits --- .../apache/spark/deploy/SparkHadoopUtil.scala | 16 ++- .../apache/spark/executor/TaskMetrics.scala | 1 - .../org/apache/spark/rdd/HadoopRDD.scala | 12 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 15 +-- .../apache/spark/rdd/PairRDDFunctions.scala | 11 +-- .../metrics/InputOutputMetricsSuite.scala | 97 ++++++++++++++++++- 6 files changed, 120 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 57f9faf5ddd1d..211e3ede53d9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -133,10 +133,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() @@ -156,10 +155,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesWritten = f() @@ -172,10 +170,8 @@ class SparkHadoopUtil extends Logging { } } - private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + private def getFileSystemThreadStatistics(): Seq[AnyRef] = { + val stats = FileSystem.getAllStatistics() stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ddb5903bf6875..97912c68c5982 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.executor import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.executor.DataReadMethod import org.apache.spark.executor.DataReadMethod.DataReadMethod import scala.collection.mutable.ArrayBuffer diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 056aef0bc210a..c3e3931042de2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ 
b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -218,13 +219,13 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.inputSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null @@ -254,7 +255,8 @@ class HadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + } else if (split.inputSplit.value.isInstanceOf[FileSplit] || + split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b0e3c87ccff4..d86f95ac3e485 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat @@ -34,7 +34,7 @@ import org.apache.spark.Logging import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils @@ -114,13 +114,13 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.serializableHadoopSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) @@ -163,7 +163,8 @@ class NewHadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0f37d830ef34f..49b88a90ab5af 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -990,7 +990,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { @@ -1061,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1086,11 +1086,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.commitJob() } - private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) - : (OutputMetrics, Option[() => Long]) = { - val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) - .map(new Path(_)) - .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) if (bytesWrittenCallback.isDefined) { context.taskMetrics.outputMetrics = Some(outputMetrics) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 10a39990f80ce..81db66ae17464 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -26,7 +26,16 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf, + LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter, + TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, + CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, + TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, + CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, + FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.spark.SharedSparkContext import org.apache.spark.deploy.SparkHadoopUtil @@ -202,7 +211,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { val fs = FileSystem.getLocal(new Configuration()) val outPath = new Path(fs.getWorkingDirectory, "outdir") - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) { + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { val taskBytesWritten = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { @@ -225,4 +234,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { } } } + + test("input metrics with old CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable], + classOf[Text], 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new CombineFileInputFormat") { + val 
bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable], + classOf[Text], new Configuration()).count() + } + assert(bytesRead >= tmpFile.length()) + } +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] { + override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter) + : OldRecordReader[LongWritable, Text] = { + new OldCombineFileRecordReader[LongWritable, Text](conf, + split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper] + .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]]) + } +} + +class OldCombineTextRecordReaderWrapper( + split: OldCombineFileSplit, + conf: Configuration, + reporter: Reporter, + idx: Integer) extends OldRecordReader[LongWritable, Text] { + + val fileSplit = new OldFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit, + conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader] + + override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value) + override def createKey(): LongWritable = delegate.createKey() + override def createValue(): Text = delegate.createValue() + override def getPos(): Long = delegate.getPos + override def close(): Unit = delegate.close() + override def getProgress(): Float = delegate.getProgress +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { + def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) + : NewRecordReader[LongWritable, Text] = { + new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + context, classOf[NewCombineTextRecordReaderWrapper]) + } } + +class NewCombineTextRecordReaderWrapper( + split: NewCombineFileSplit, + context: TaskAttemptContext, + idx: Integer) extends NewRecordReader[LongWritable, Text] { + + val fileSplit = new NewFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context) + + override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = { + delegate.initialize(fileSplit, context) + } + + override def nextKeyValue(): Boolean = delegate.nextKeyValue() + override def getCurrentKey(): LongWritable = delegate.getCurrentKey + override def getCurrentValue(): Text = delegate.getCurrentValue + override def getProgress(): Float = delegate.getProgress + override def close(): Unit = delegate.close() +} \ No newline at end of file From 119f45d61d7b48d376cca05e1b4f0c7fcf65bfa8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 27 Jan 2015 16:08:24 -0800 Subject: [PATCH 41/41] [SPARK-5097][SQL] DataFrame This pull request redesigns the existing Spark SQL dsl, which already provides data frame like functionalities. TODOs: With the exception of Python support, other tasks can be done in separate, follow-up PRs. 
- [ ] Audit of the API - [ ] Documentation - [ ] More test cases to cover the new API - [x] Python support - [ ] Type alias SchemaRDD Author: Reynold Xin Author: Davies Liu Closes #4173 from rxin/df1 and squashes the following commits: 0a1a73b [Reynold Xin] Merge branch 'df1' of github.com:rxin/spark into df1 23b4427 [Reynold Xin] Mima. 828f70d [Reynold Xin] Merge pull request #7 from davies/df 257b9e6 [Davies Liu] add repartition 6bf2b73 [Davies Liu] fix collect with UDT and tests e971078 [Reynold Xin] Missing quotes. b9306b4 [Reynold Xin] Remove removeColumn/updateColumn for now. a728bf2 [Reynold Xin] Example rename. e8aa3d3 [Reynold Xin] groupby -> groupBy. 9662c9e [Davies Liu] improve DataFrame Python API 4ae51ea [Davies Liu] python API for dataframe 1e5e454 [Reynold Xin] Fixed a bug with symbol conversion. 2ca74db [Reynold Xin] Couple minor fixes. ea98ea1 [Reynold Xin] Documentation & literal expressions. 2b22684 [Reynold Xin] Got rid of IntelliJ problems. 02bbfbc [Reynold Xin] Tightening imports. ffbce66 [Reynold Xin] Fixed compilation error. 59b6d8b [Reynold Xin] Style violation. b85edfb [Reynold Xin] ALS. 8c37f0a [Reynold Xin] Made MLlib and examples compile 6d53134 [Reynold Xin] Hive module. d35efd5 [Reynold Xin] Fixed compilation error. ce4a5d2 [Reynold Xin] Fixed test cases in SQL except ParquetIOSuite. 66d5ef1 [Reynold Xin] SQLContext minor patch. c9bcdc0 [Reynold Xin] Checkpoint: SQL module compiles! --- .../ml/JavaCrossValidatorExample.java | 10 +- .../examples/ml/JavaSimpleParamsExample.java | 12 +- .../JavaSimpleTextClassificationPipeline.java | 10 +- .../spark/examples/sql/JavaSparkSQL.java | 36 +- .../src/main/python/mllib/dataset_example.py | 2 +- examples/src/main/python/sql.py | 16 +- .../examples/ml/CrossValidatorExample.scala | 3 +- .../spark/examples/ml/MovieLensALS.scala | 2 +- .../examples/ml/SimpleParamsExample.scala | 5 +- .../ml/SimpleTextClassificationPipeline.scala | 3 +- .../spark/examples/mllib/DatasetExample.scala | 28 +- .../spark/examples/sql/RDDRelation.scala | 6 +- .../scala/org/apache/spark/ml/Estimator.scala | 8 +- .../scala/org/apache/spark/ml/Evaluator.scala | 4 +- .../scala/org/apache/spark/ml/Pipeline.scala | 6 +- .../org/apache/spark/ml/Transformer.scala | 17 +- .../classification/LogisticRegression.scala | 14 +- .../BinaryClassificationEvaluator.scala | 7 +- .../spark/ml/feature/StandardScaler.scala | 15 +- .../apache/spark/ml/recommendation/ALS.scala | 37 +- .../spark/ml/tuning/CrossValidator.scala | 8 +- .../apache/spark/ml/JavaPipelineSuite.java | 6 +- .../JavaLogisticRegressionSuite.java | 8 +- .../ml/tuning/JavaCrossValidatorSuite.java | 4 +- .../org/apache/spark/ml/PipelineSuite.scala | 14 +- .../LogisticRegressionSuite.scala | 16 +- .../spark/ml/recommendation/ALSSuite.scala | 4 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 4 +- project/MimaExcludes.scala | 15 +- python/pyspark/java_gateway.py | 7 +- python/pyspark/sql.py | 967 +++++++++++++----- python/pyspark/tests.py | 155 +-- .../analysis/MultiInstanceRelation.scala | 2 +- .../expressions/namedExpressions.scala | 3 + .../spark/sql/catalyst/plans/joinTypes.scala | 15 + .../catalyst/plans/logical/TestRelation.scala | 8 +- .../org/apache/spark/sql/CacheManager.scala | 8 +- .../scala/org/apache/spark/sql/Column.scala | 528 ++++++++++ .../org/apache/spark/sql/DataFrame.scala | 596 +++++++++++ .../apache/spark/sql/GroupedDataFrame.scala | 139 +++ .../scala/org/apache/spark/sql/Literal.scala | 98 ++ .../org/apache/spark/sql/SQLContext.scala | 85 +- 
.../org/apache/spark/sql/SchemaRDD.scala | 511 --------- .../org/apache/spark/sql/SchemaRDDLike.scala | 139 --- .../main/scala/org/apache/spark/sql/api.scala | 289 ++++++ .../org/apache/spark/sql/dsl/package.scala | 495 +++++++++ .../apache/spark/sql/execution/commands.scala | 8 +- .../spark/sql/execution/debug/package.scala | 4 +- .../scala/org/apache/spark/sql/package.scala | 2 +- .../spark/sql/parquet/ParquetTest.scala | 6 +- .../sql/sources/DataSourceStrategy.scala | 32 +- .../org/apache/spark/sql/sources/ddl.scala | 5 +- .../spark/sql/test/TestSQLContext.scala | 6 +- .../spark/sql/api/java/JavaAPISuite.java | 4 +- .../sql/api/java/JavaApplySchemaSuite.java | 16 +- .../apache/spark/sql/CachedTableSuite.scala | 1 + .../org/apache/spark/sql/DslQuerySuite.scala | 119 +-- .../org/apache/spark/sql/JoinSuite.scala | 67 +- .../org/apache/spark/sql/QueryTest.scala | 12 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 18 +- .../scala/org/apache/spark/sql/TestData.scala | 23 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 11 +- .../spark/sql/UserDefinedTypeSuite.scala | 6 +- .../columnar/InMemoryColumnarQuerySuite.scala | 1 + .../spark/sql/execution/PlannerSuite.scala | 14 +- .../apache/spark/sql/execution/TgfSuite.scala | 65 -- .../org/apache/spark/sql/json/JsonSuite.scala | 11 +- .../sql/parquet/ParquetFilterSuite.scala | 126 +-- .../spark/sql/parquet/ParquetIOSuite.scala | 7 +- .../spark/sql/sources/PrunedScanSuite.scala | 2 + .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- .../spark/sql/hive/thriftserver/Shim12.scala | 6 +- .../spark/sql/hive/thriftserver/Shim13.scala | 6 +- .../apache/spark/sql/hive/HiveContext.scala | 9 +- .../spark/sql/hive/HiveStrategies.scala | 17 +- .../org/apache/spark/sql/hive/TestHive.scala | 9 +- .../org/apache/spark/sql/QueryTest.scala | 10 +- .../spark/sql/hive/CachedTableSuite.scala | 4 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 7 +- .../hive/execution/HiveTableScanSuite.scala | 11 +- .../sql/hive/execution/HiveUdfSuite.scala | 2 +- 82 files changed, 3444 insertions(+), 1572 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/Column.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/Literal.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 247d2a5e31a8c..0fbee6e433608 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,7 +33,7 @@ import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import 
org.apache.spark.sql.Row; @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,11 +112,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + cvModel.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 5b92655e2e838..eaaa344be49c8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -48,13 +48,13 @@ public static void main(String[] args) { // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -94,14 +94,14 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
// Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerAsTable("results"); - SchemaRDD results = + model2.transform(test).registerTempTable("results"); + DataFrame results = jsql.sql("SELECT features, label, probability, prediction FROM results"); for (Row r: results.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 74db449fada7d..82d665a3e1386 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -29,7 +29,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,11 +79,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - model.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + model.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index b70804635d5c9..8defb769ffaaf 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaSparkSQL { public static class Person implements Serializable { @@ -74,13 +74,13 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. 
- SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.toJavaRDD().map(new Function() { @Override @@ -93,17 +93,17 @@ public String call(Row row) { } System.out.println("=== Data source: Parquet File ==="); - // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. + // DataFrames can be saved as parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. - // The result of loading a parquet file is also a JavaSchemaRDD. - SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + // The result of loading a parquet file is also a DataFrame. + DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - SchemaRDD teenagers2 = + DataFrame teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override @@ -119,8 +119,8 @@ public String call(Row row) { // A JSON dataset is pointed by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; - // Create a JavaSchemaRDD from the file(s) pointed by path - SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + // Create a DataFrame from the file(s) pointed by path + DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -130,13 +130,13 @@ public String call(Row row) { // |-- age: IntegerType // |-- name: StringType - // Register this JavaSchemaRDD as a table. + // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. - SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagerNames = teenagers3.toJavaRDD().map(new Function() { @Override @@ -146,14 +146,14 @@ public String call(Row row) { System.out.println(name); } - // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // Alternatively, a DataFrame can be created for a JSON dataset represented by // a RDD[String] storing one JSON object per string. 
List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); - // Take a look at the schema of this new JavaSchemaRDD. + // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); // The schema of anotherPeople is ... // root @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index 540dae785f6ea..b5a70db2b9a3c 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -16,7 +16,7 @@ # """ -An example of how to use SchemaRDD as a dataset for ML. Run with:: +An example of how to use DataFrame as a dataset for ML. Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d2c5ca48c6cb8..7f5c68e3d0fe2 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -30,18 +30,18 @@ some_rdd = sc.parallelize([Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]) - # Infer schema from the first row, create a SchemaRDD and print the schema - some_schemardd = sqlContext.inferSchema(some_rdd) - some_schemardd.printSchema() + # Infer schema from the first row, create a DataFrame and print the schema + some_df = sqlContext.inferSchema(some_rdd) + some_df.printSchema() # Another RDD is created from a list of tuples another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)]) # Schema with two fields - person_name and person_age schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) - # Create a SchemaRDD by applying the schema to the RDD and print the schema - another_schemardd = sqlContext.applySchema(another_rdd, schema) - another_schemardd.printSchema() + # Create a DataFrame by applying the schema to the RDD and print the schema + another_df = sqlContext.applySchema(another_rdd, schema) + another_df.printSchema() # root # |-- age: integer (nullable = true) # |-- name: string (nullable = true) @@ -49,7 +49,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") - # Create a SchemaRDD from the file(s) pointed to by path + # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root # |-- person_name: string (nullable = false) @@ -61,7 +61,7 @@ # |-- age: IntegerType # |-- name: StringType - # Register this SchemaRDD as a table. + # Register this DataFrame as a table. 
people.registerAsTable("people") # SQL statements can be run by using the sql methods provided by sqlContext diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index d8c7ef38ee46d..283bb80f1c788 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator @@ -101,7 +100,7 @@ object CrossValidatorExample { // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cf62772b92651..b7885829459a3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -143,7 +143,7 @@ object MovieLensALS { // Evaluate the model. // TODO: Create an evaluator to compute RMSE. - val mse = predictions.select('rating, 'prediction) + val mse = predictions.select("rating", "prediction").rdd .flatMap { case Row(rating: Float, prediction: Float) => val err = rating.toDouble - prediction val err2 = err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a2adff929cb..95cc9801eaeb9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -42,7 +41,7 @@ object SimpleParamsExample { // Prepare training data. // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. val training = sparkContext.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), @@ -92,7 +91,7 @@ object SimpleParamsExample { // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. 
model2.transform(test) - .select('features, 'label, 'probability, 'prediction) + .select("features", "label", "probability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index b9a6ef0229def..065db62b0f5ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,7 +20,6 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} @@ -80,7 +79,7 @@ object SimpleTextClassificationPipeline { // Make predictions on test documents. model.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index f8d83f4ec7327..f229a58985a3e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{Row, SQLContext, DataFrame} /** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} @@ -47,7 +47,7 @@ object DatasetExample { val defaultParams = Params() val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + head("Dataset: an example app using DataFrame as a Dataset for ML.") opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) @@ -80,20 +80,20 @@ object DatasetExample { } println(s"Loaded ${origData.count()} instances from file: ${params.input}") - // Convert input data to SchemaRDD explicitly. - val schemaRDD: SchemaRDD = origData - println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") - println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + // Convert input data to DataFrame explicitly. + val df: DataFrame = origData.toDF + println(s"Inferred schema:\n${df.schema.prettyJson}") + println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to SchemaRDD. 
- val labelsSchemaRDD: SchemaRDD = origData.select('label) - val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + // Select columns, using implicit conversion to DataFrames. + val labelsDf: DataFrame = origData.select("label") + val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresSchemaRDD: SchemaRDD = origData.select('features) - val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featuresDf: DataFrame = origData.select("features") + val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) @@ -103,13 +103,13 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - schemaRDD.saveAsParquetFile(outputDir) + df.saveAsParquetFile(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2e98b2dc30b80..a5d7f262581f5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,6 +19,8 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.dsl._ +import org.apache.spark.sql.dsl.literals._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -54,7 +56,7 @@ object RDDRelation { rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. rdd.saveAsParquetFile("pair.parquet") @@ -63,7 +65,7 @@ object RDDRelation { val parquetFile = sqlContext.parquetFile("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. - parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. 
parquetFile.registerTempTable("parquetFile") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 77d230eb4a122..bc3defe968afd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -21,7 +21,7 @@ import scala.annotation.varargs import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -38,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @return fitted model */ @varargs - def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { val map = new ParamMap().put(paramPairs: _*) fit(dataset, map) } @@ -50,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMap parameter map * @return fitted model */ - def fit(dataset: SchemaRDD, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -61,7 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMaps an array of parameter maps * @return fitted models, matching the input parameter maps */ - def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index db563dd550e56..d2ca2e6871e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index ad6fed178fae9..fe39cd1bc0bd2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] { * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap val theStages = map(stages) @@ -162,7 +162,7 @@ class PipelineModel 
private[ml] ( } } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap transformSchema(dataset.schema, map, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index af56f9c435351..b233bff08305c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,9 +22,9 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.types._ /** @@ -41,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ @varargs - def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() paramPairs.foreach(map.put(_)) transform(dataset, map) @@ -53,7 +53,7 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame } /** @@ -95,11 +95,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O StructType(outputFields) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) - dataset.select(Star(None), udf as map(outputCol)) + dataset.select($"*", callUDF( + this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f8316..eeb6301c3f64a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti def setScoreCol(value: String): this.type = set(scoreCol, value) def 
setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + val instances = dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }.persist(StorageLevel.MEMORY_AND_DISK) @@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] ( validateAndTransformSchema(schema, paramMap, fitting = false) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val score: Vector => Double = (v) => { val margin = BLAS.dot(v, weights) @@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] ( val predict: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } - dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol))) + .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 12473cb2b5719..1979ab9eb6516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** @@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params def setScoreCol(value: String): this.type = set(scoreCol, value) def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema @@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params require(labelType == DoubleType, s"Label column ${map(labelCol)} must be double type but found $labelType") - import dataset.sqlContext._ - val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) .map { case Row(score: Double, label: Double) => (score, label) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72825f6e02182..e7bdb070c8193 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.types.{StructField, StructType} @@ -43,14 +43,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol).attr) - .map { case Row(v: Vector) => - v - } + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) Params.inheritValues(map, this, model) @@ -83,14 +79,13 @@ class StandardScalerModel private[ml] ( def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val scale: (Vector) => Vector = (v) => { scaler.transform(v) } - dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol))) } private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2d89e76a4c8b2..f6437c7fbc8ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} @@ -112,7 +110,7 @@ class ALSModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext._ import org.apache.spark.ml.recommendation.ALSModel.Factor val map = this.paramMap ++ paramMap @@ -120,13 +118,13 @@ class ALSModel 
private[ml] ( val instanceTable = s"instance_$uid" val userTable = s"user_$uid" val itemTable = s"item_$uid" - val instances = dataset.as(Symbol(instanceTable)) + val instances = dataset.as(instanceTable) val users = userFactors.map { case (id, features) => Factor(id, features) - }.as(Symbol(userTable)) + }.as(userTable) val items = itemFactors.map { case (id, features) => Factor(id, features) - }.as(Symbol(itemTable)) + }.as(itemTable) val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { if (userFeatures != null && itemFeatures != null) { blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) @@ -135,12 +133,12 @@ class ALSModel private[ml] ( } } val inputColumns = dataset.schema.fieldNames - val prediction = - predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol) - val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction + val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features") + .as(map(predictionCol)) + val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction instances - .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr)) - .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr)) + .join(users, Column(map(userCol)) === $"$userTable.id", "left") + .join(items, Column(map(itemCol)) === $"$itemTable.id", "left") .select(outputColumns: _*) } @@ -209,14 +207,13 @@ class ALS extends Estimator[ALSModel] with ALSParams { setMaxIter(20) setRegParam(1.0) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = { - import dataset.sqlContext._ + override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { val map = this.paramMap ++ paramMap - val ratings = - dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType)) - .map { row => - new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) - } + val ratings = dataset + .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType)) + .map { row => + new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 08fe99176424a..5d51c51346665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setEvaluator(value: Evaluator): this.type = set(evaluator, value) def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { val map = this.paramMap ++ paramMap val schema = dataset.schema 
transformSchema(dataset.schema, paramMap, logging = true) @@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val epm = map(estimatorParamMaps) val numModels = epm.size val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset, map(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.applySchema(training, schema).cache() val validationDataset = sqlCtx.applySchema(validation, schema).cache() @@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 47f1f46c6c260..56a9dbdd58b64 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +37,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 2eba83335bb58..f4ba23c44563e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -34,7 +34,7 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -55,7 +55,7 @@ public void logisticRegression() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - 
SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } @@ -67,7 +67,7 @@ public void logisticRegressionWithSetters() { LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold .registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a9f1c4a2c3ca7..074b58c07df7a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +38,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4515084bc7ae9..2f175fb117941 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame class PipelineSuite extends FunSuite { @@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite { val estimator2 = mock[Estimator[MyModel]] val model2 = mock[MyModel] val transformer3 = mock[Transformer] - val dataset0 = mock[SchemaRDD] - val dataset1 = mock[SchemaRDD] - val dataset2 = mock[SchemaRDD] - val dataset3 = mock[SchemaRDD] - val dataset4 = mock[SchemaRDD] + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) @@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite { val estimator = mock[Estimator[MyModel]] val pipeline = new Pipeline() .setStages(Array(estimator, estimator)) - val dataset = mock[SchemaRDD] + val dataset = mock[DataFrame] intercept[IllegalArgumentException] { pipeline.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e8030fef55b1d..1912afce93b18 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -21,12 +21,12 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -36,34 +36,28 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { } test("logistic regression") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset) model.transform(dataset) - .select('label, 'prediction) + .select("label", "prediction") .collect() } test("logistic regression with setters") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) val model = lr.fit(dataset) model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select('label, 'score, 'prediction) + .select("label", "score", "prediction") .collect() } test("logistic regression fit and transform with varargs") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select('label, 'probability, 'prediction) + .select("label", "probability", "prediction") .collect() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index cdd4db1b5b7dc..58289acdbc095 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { val sqlContext = this.sqlContext - import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute} + import sqlContext.createSchemaRDD val als = new ALS() .setRank(rank) .setRegParam(regParam) @@ -360,7 +360,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val alpha = als.getAlpha val model = als.fit(training) val predictions = model.transform(test) - .select('rating, 'prediction) + .select("rating", "prediction") .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 41cc13da4d5b1..74104fa7a681a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -23,11 +23,11 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import 
org.apache.spark.sql.{SQLContext, DataFrame} class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index af0b0ebb9a383..e750fed7448cd 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -95,7 +95,20 @@ object MimaExcludes { ) ++ Seq( // SPARK-5166 Spark SQL API stabilization ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") ) ++ Seq( // SPARK-5270 ProblemFilters.exclude[MissingMethodProblem]( diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a975dc19cb78e..a0a028446d5fd 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -111,10 +111,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + # TODO(davies): move into sql + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1990323249cf6..7d7550c854b2f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,15 +20,19 @@ - L{SQLContext} Main entry point for SQL functionality. - - L{SchemaRDD} + - L{DataFrame} A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. 
+ addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedDataFrame} + - L{Column} + Column is a DataFrame with a single column. - L{Row} A Row of data returned by a Spark SQL query. - L{HiveContext} Main entry point for accessing data stored in Apache Hive.. """ +import sys import itertools import decimal import datetime @@ -36,6 +40,9 @@ import warnings import json import re +import random +import os +from tempfile import NamedTemporaryFile from array import array from operator import itemgetter from itertools import imap @@ -43,6 +50,7 @@ from py4j.protocol import Py4JError from py4j.java_collections import ListConverter, MapConverter +from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ CloudPickleSerializer, UTF8Deserializer @@ -54,7 +62,8 @@ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", + "SchemaRDD"] class DataType(object): @@ -1171,7 +1180,7 @@ def Dict(d): class Row(tuple): - """ Row in SchemaRDD """ + """ Row in DataFrame """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) __slots__ = () @@ -1198,7 +1207,7 @@ class SQLContext(object): """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as tables, execute SQL over tables, cache tables, and read parquet files. """ @@ -1209,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None): :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... @@ -1225,12 +1234,12 @@ def __init__(self, sparkContext, sqlContext=None): >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerTempTable("allTypes") + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, ... x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -1309,23 +1318,23 @@ def inferSchema(self, rdd, samplingRatio=None): ... [Row(field1=1, field2="row1"), ... Row(field1=2, field2="row2"), ... Row(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') >>> NestedRow = Row("f1", "f2") >>> nestedRdd1 = sc.parallelize([ ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), ... 
NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] >>> nestedRdd2 = sc.parallelize([ ... NestedRow([[1, 2], [2, 3]], [1, 2]), ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] >>> from collections import namedtuple @@ -1334,13 +1343,13 @@ def inferSchema(self, rdd, samplingRatio=None): ... [CustomRow(field1=1, field2="row1"), ... CustomRow(field1=2, field2="row2"), ... CustomRow(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") first = rdd.first() if not first: @@ -1384,10 +1393,10 @@ def applySchema(self, rdd, schema): >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import date, datetime @@ -1410,15 +1419,15 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) - >>> results = srdd.map( + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, ... x.time, x.map["a"], x.struct.b, x.list, x.null)) >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - >>> srdd.registerTempTable("table2") + >>> df.registerTempTable("table2") >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... 
"short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + @@ -1431,13 +1440,13 @@ def applySchema(self, rdd, schema): >>> abstract = "byte short float time map{} struct(b) list[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> srdd = sqlCtx.applySchema(rdd, typedSchema) - >>> srdd.collect() + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): raise TypeError("schema should be StructType") @@ -1457,8 +1466,8 @@ def applySchema(self, rdd, schema): rdd = rdd.map(converter) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -1466,34 +1475,34 @@ def registerRDDAsTable(self, rdd, tableName): Temporary tables exist only during the lifetime of this instance of SQLContext. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") """ - if (rdd.__class__ is SchemaRDD): - srdd = rdd._jschema_rdd.baseSchemaRDD() - self._ssql_ctx.registerRDDAsTable(srdd, tableName) + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) else: - raise ValueError("Can only register SchemaRDD as table") + raise ValueError("Can only register DataFrame as table") def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{SchemaRDD}. + """Loads a Parquet file, returning the result as a L{DataFrame}. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) - return SchemaRDD(jschema_rdd, self) + jdf = self._ssql_ctx.parquetFile(path) + return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a - L{SchemaRDD}. + L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1508,23 +1517,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... 
print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) @@ -1536,23 +1545,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path, samplingRatio) + df = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. + """Loads an RDD storing one JSON object per string as a L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1560,23 +1569,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): Otherwise, it samples the dataset with ratio `samplingRatio` to determine the schema. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) @@ -1588,12 +1597,12 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... 
ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] >>> sqlCtx.jsonRDD(sc.parallelize(['{}', @@ -1615,33 +1624,33 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{SchemaRDD} representing the result of the given query. + """Return a L{DataFrame} representing the result of the given query. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{SchemaRDD}. + """Returns the specified table as a L{DataFrame}. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.table("table1") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return DataFrame(self._ssql_ctx.table(tableName), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1707,7 +1716,7 @@ def _create_row(fields, values): class Row(tuple): """ - A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + A row in L{DataFrame}. The fields in it can be accessed like attributes. Row can be used to create a row object by using named arguments, the fields will be sorted by names. @@ -1799,111 +1808,119 @@ def inherit_doc(cls): return cls -@inherit_doc -class SchemaRDD(RDD): +class DataFrame(object): - """An RDD of L{Row} objects that has an associated schema. + """A collection of rows that have the same columns. - The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by Spark SQL. + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: - For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the - L{SchemaRDD} is not operated on directly, as it's underlying - implementation is an RDD composed of Java objects. Instead it is - converted to a PythonRDD in the JVM, on which Python operations can - be done. 
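# A minimal sketch (not part of this patch) walking through the DSL that the new
# DataFrame docstring describes, assuming an existing SQLContext bound to
# `sqlContext` and two hypothetical Parquet datasets on disk.
people = sqlContext.parquetFile("/tmp/people.parquet")
department = sqlContext.parquetFile("/tmp/department.parquet")
age_col = people.age                                    # column access via attribute
older = people.filter(people.age > 30)                  # Column expressions build filters
joined = older.join(department, people.deptId == department.id, "left_outer")
summary = joined.groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
summary.collect()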
+ people = sqlContext.parquetFile("...") - This class receives raw tuples from Java but assigns a class to it in - all its data-collection methods (mapPartitionsWithIndex, collect, take, - etc) so that PySpark sees them as Row objects with named fields. + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: [[DataFrame]], [[Column]]. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. + people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) """ - def __init__(self, jschema_rdd, sql_ctx): + def __init__(self, jdf, sql_ctx): + self._jdf = jdf self.sql_ctx = sql_ctx - self._sc = sql_ctx._sc - clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD" - self._jschema_rdd = jschema_rdd - self._id = None + self._sc = sql_ctx and sql_ctx._sc self.is_cached = False - self.is_checkpointed = False - self.ctx = self.sql_ctx._sc - # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property - def _jrdd(self): - """Lazy evaluation of PythonRDD object. + def rdd(self): + """Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row`s. """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() - Only done when a user calls methods defined by the - L{pyspark.rdd.RDD} super class (map, filter, etc.). - """ - if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() - return self._lazy_jrdd + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) - def id(self): - if self._id is None: - self._id = self._jrdd.id() - return self._id + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd def limit(self, num): """Limit the result count to the number specified. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.limit(2).collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.limit(2).collect() [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> srdd.limit(0).collect() + >>> df.limit(0).collect() [] """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) def toJSON(self, use_unicode=False): - """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. 
- >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( "SELECT * from table1") - >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( "SELECT * from table1") + >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' True - >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] True """ - rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a SchemaRDD using the L{SQLContext.parquetFile} method. + a DataFrame using the L{SQLContext.parquetFile} method. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd2.collect()) == sorted(srdd.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) True """ - self._jschema_rdd.saveAsParquetFile(path) + self._jdf.saveAsParquetFile(path) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this SchemaRDD. + that was used to create this DataFrame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerTempTable("test") - >>> srdd2 = sqlCtx.sql("select * from test") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.registerTempTable("test") + >>> df2 = sqlCtx.sql("select * from test") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - self._jschema_rdd.registerTempTable(name) + self._jdf.registerTempTable(name) def registerAsTable(self, name): """DEPRECATED: use registerTempTable() instead""" @@ -1911,62 +1928,61 @@ def registerAsTable(self, name): self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this SchemaRDD into the specified table. + """Inserts the contents of this DataFrame into the specified table. Optionally overwriting any existing data. 
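# A short sketch (not from the patch) of the table helpers defined above, assuming
# a SQLContext named `sqlCtx` and a hypothetical DataFrame `df` with a `field1` column.
df.registerTempTable("people_tmp")
subset = sqlCtx.sql("SELECT field1 FROM people_tmp")
json_lines = subset.toJSON().collect()                  # one JSON document per row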
""" - self._jschema_rdd.insertInto(tableName, overwrite) + self._jdf.insertInto(tableName, overwrite) def saveAsTable(self, tableName): - """Creates a new table with the contents of this SchemaRDD.""" - self._jschema_rdd.saveAsTable(tableName) + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) def schema(self): - """Returns the schema of this SchemaRDD (represented by + """Returns the schema of this DataFrame (represented by a L{StructType}).""" - return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) - - def schemaString(self): - """Returns the output schema in the tree format.""" - return self._jschema_rdd.schemaString() + return _parse_datatype_json_string(self._jdf.schema().json()) def printSchema(self): """Prints out the schema in the tree format.""" - print self.schemaString() + print (self._jdf.schema().treeString()) def count(self): """Return the number of elements in this RDD. Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the SchemaRDD, + leverages the query optimizer to compute the count on the DataFrame, which supports features such as filter pushdown. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.count() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.count() 3L - >>> srdd.count() == srdd.map(lambda x: x).count() + >>> df.count() == df.map(lambda x: x).count() True """ - return self._jschema_rdd.count() + return self._jdf.count() def collect(self): - """Return a list that contains all of the rows in this RDD. + """Return a list that contains all of the rows. Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of collect, this implementation - leverages the query optimizer to perform a collect on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect() [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] """ - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + return [cls(r) for r in rs] def take(self, num): """Take the first num rows of the RDD. @@ -1974,130 +1990,555 @@ def take(self, num): Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of take, this implementation - leverages the query optimizer to perform a collect on a SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.take(2) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.take(2) [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] """ return self.limit(num).collect() - # Convert each object in the RDD to a Row with the right class - # for this SchemaRDD, so that fields can be accessed as attributes. 
- def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + def map(self, f): + """ Return a new RDD by applying a function to each Row, it's a + shorthand for df.rdd.map() """ - Return a new RDD by applying a function to each partition of this RDD, - while tracking the index of the original partition. + return self.rdd.map(f) - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(splitIndex, iterator): yield splitIndex - >>> rdd.mapPartitionsWithIndex(f).sum() - 6 + def mapPartitions(self, f, preservesPartitioning=False): """ - rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - - schema = self.schema() + Return a new RDD by applying a function to each partition. - def applySchema(_, it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) - return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) - # We override the default cache/persist/checkpoint behavior - # as we want to cache the underlying SchemaRDD object in the JVM, - # not the PythonRDD checkpointed by the super class def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - self._jschema_rdd.cache() + self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) - self._jschema_rdd.persist(javaStorageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) return self def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. + """ self.is_cached = False - self._jschema_rdd.unpersist(blocking) + self._jdf.unpersist(blocking) return self - def checkpoint(self): - self.is_checkpointed = True - self._jschema_rdd.checkpoint() + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) - def isCheckpointed(self): - return self._jschema_rdd.isCheckpointed() + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. + """ + rdd = self._jdf.repartition(numPartitions, None) + return DataFrame(rdd, self.sql_ctx) - def getCheckpointFile(self): - checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): - return checkpointFile.get() + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. 
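# Sketch (not from the patch) of the RDD passthroughs and caching helpers above,
# assuming a hypothetical DataFrame `df` with an integer `field1` column.
ids = df.map(lambda row: row.field1).collect()          # shorthand for df.rdd.map(...)
df.cache()                                              # caches the underlying JVM DataFrame
df.unpersist()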
- def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None) - return SchemaRDD(rdd, self.sql_ctx) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.sample(False, 0.5, 97).count() + 2L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) - def distinct(self, numPartitions=None): - if numPartitions is None: - rdd = self._jschema_rdd.distinct() + @property + def dtypes(self): + """Return all column names and their data types as a list. + """ + return [(f.name, str(f.dataType)) for f in self.schema().fields] + + @property + def columns(self): + """ Return all column names as a list. + """ + return [f.name for f in self.schema().fields] + + def show(self): + raise NotImplemented + + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another DataFrame, using the given join expression. + The following performs a full outer join between `df1` and `df2`:: + + df1.join(df2, df1.key == df2.key, "outer") + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, + `semijoin`. + """ + if joinType is None: + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + jdf = self._jdf.join(other._jdf, joinExprs) else: - rdd = self._jschema_rdd.distinct(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.join(other._jdf, joinExprs, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new [[DataFrame]] sorted by the specified column, + in ascending column. + + :param cols: The columns or expressions used for sorting + """ + if not cols: + raise ValueError("should sort by at least one column") + for i, c in enumerate(cols): + if isinstance(c, basestring): + cols[i] = Column(c) + jcols = [c._jc for c in cols] + jdf = self._jdf.join(*jcols) + return DataFrame(jdf, self.sql_ctx) + + sortBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. 
""" + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def tail(self): + raise NotImplemented + + def __getitem__(self, item): + if isinstance(item, basestring): + return Column(self._jdf.apply(item)) + + # TODO projection + raise IndexError + + def __getattr__(self, name): + """ Return the column by given name """ + if isinstance(name, basestring): + return Column(self._jdf.apply(name)) + raise AttributeError + + def As(self, name): + """ Alias the current DataFrame """ + return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) + + def select(self, *cols): + """ Selecting a set of expressions.:: + + df.select() + df.select('colA', 'colB') + df.select(df.colA, df.colB + 1) - def intersection(self, other): - if (other.__class__ is SchemaRDD): - rdd = self._jschema_rdd.intersection(other._jschema_rdd) - return SchemaRDD(rdd, self.sql_ctx) + """ + if not cols: + cols = ["*"] + if isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only intersect with another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.select(self._jdf.toColumnArray(jcols)) + return DataFrame(jdf, self.sql_ctx) - def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + def filter(self, condition): + """ Filtering rows using the given condition:: - def subtract(self, other, numPartitions=None): - if (other.__class__ is SchemaRDD): - if numPartitions is None: - rdd = self._jschema_rdd.subtract(other._jschema_rdd) - else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) - return SchemaRDD(rdd, self.sql_ctx) + df.filter(df.age > 15) + df.where(df.age > 15) + + """ + return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the [[DataFrame]] using the specified columns, + so we can run aggregation on them. See :class:`GroupedDataFrame` + for all the available aggregate functions:: + + df.groupBy(df.department).avg() + df.groupBy("department", "gender").agg({ + "salary": "avg", + "age": "max", + }) + """ + if cols and isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only subtract another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols)) + return GroupedDataFrame(jdf, self.sql_ctx) - def sample(self, withReplacement, fraction, seed=None): + def agg(self, *exprs): + """ Aggregate on the entire [[DataFrame]] without groups + (shorthand for df.groupBy.agg()):: + + df.agg({"age": "max", "salary": "avg"}) """ - Return a sampled subset of this SchemaRDD. + return self.groupBy().agg(*exprs) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.sample(False, 0.5, 97).count() - 2L + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. 
""" - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) - return SchemaRDD(rdd, self.sql_ctx) + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - def takeSample(self, withReplacement, num, seed=None): - """Return a fixed-size sampled subset of this SchemaRDD. + def intersect(self, other): + """ Return a new [[DataFrame]] containing rows only in + both this frame and another frame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.takeSample(False, 2, 97) - [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + This is equivalent to `INTERSECT` in SQL. """ - seed = seed if seed is not None else random.randint(0, sys.maxint) - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD() \ - .takeSampleToPython(withReplacement, num, long(seed)) \ - .iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def Except(self, other): + """ Return a new [[DataFrame]] containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ Return a new DataFrame by sampling a fraction of rows. """ + if seed is None: + jdf = self._jdf.sample(withReplacement, fraction) + else: + jdf = self._jdf.sample(withReplacement, fraction, seed) + return DataFrame(jdf, self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new [[DataFrame]] by adding a column. """ + return self.select('*', col.As(colName)) + + def removeColumn(self, colName): + raise NotImplemented + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedDataFrame(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to agregate methods. + """ + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns" + jdf = self._jdf.agg(*exprs) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. """ + + @dfapi + def mean(self): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" + + @dfapi + def max(self): + """Compute the max value for each numeric columns for + each group. 
""" + + @dfapi + def min(self): + """Compute the min value for each numeric column for + each group.""" + + @dfapi + def sum(self): + """Compute the sum for each numeric columns for each + group.""" + + +SCALA_METHOD_MAPPINGS = { + '=': '$eq', + '>': '$greater', + '<': '$less', + '+': '$plus', + '-': '$minus', + '*': '$times', + '/': '$div', + '!': '$bang', + '@': '$at', + '#': '$hash', + '%': '$percent', + '^': '$up', + '&': '$amp', + '~': '$tilde', + '?': '$qmark', + '|': '$bar', + '\\': '$bslash', + ':': '$colon', +} + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.Literal.apply(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Column(name) + + +def _scalaMethod(name): + """ Translate operators into methodName in Scala + + For example: + >>> _scalaMethod('+') + '$plus' + >>> _scalaMethod('>=') + '$greater$eq' + >>> _scalaMethod('cast') + 'cast' + """ + return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) + + +def _unary_op(name): + """ Create a method for given unary operator """ + def _(self): + return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) + return _ + + +def _bin_op(name): + """ Create a method for given binary operator """ + def _(self, other): + if isinstance(other, Column): + jc = other._jc + else: + jc = _create_column_from_literal(other) + return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) + return _ + + +def _reverse_op(name): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), + self._jdf, self.sql_ctx) + return _ + + +class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by: + {{{ + // 1. Select a column out of a DataFrame + df.colName + df["colName"] + + // 2. 
Create from an expression + df["colName"] + 1 + }}} + """ + + def __init__(self, jc, jdf=None, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jdf, sql_ctx) + + # arithmetic operators + __neg__ = _unary_op("unary_-") + __add__ = _bin_op("+") + __sub__ = _bin_op("-") + __mul__ = _bin_op("*") + __div__ = _bin_op("/") + __mod__ = _bin_op("%") + __radd__ = _bin_op("+") + __rsub__ = _reverse_op("-") + __rmul__ = _bin_op("*") + __rdiv__ = _reverse_op("/") + __rmod__ = _reverse_op("%") + __abs__ = _unary_op("abs") + abs = _unary_op("abs") + sqrt = _unary_op("sqrt") + + # logistic operators + __eq__ = _bin_op("===") + __ne__ = _bin_op("!==") + __lt__ = _bin_op("<") + __le__ = _bin_op("<=") + __ge__ = _bin_op(">=") + __gt__ = _bin_op(">") + # `and`, `or`, `not` cannot be overloaded in Python + And = _bin_op('&&') + Or = _bin_op('||') + Not = _unary_op('unary_!') + + # bitwise operators + __and__ = _bin_op("&") + __or__ = _bin_op("|") + __invert__ = _unary_op("unary_~") + __xor__ = _bin_op("^") + # __lshift__ = _bin_op("<<") + # __rshift__ = _bin_op(">>") + __rand__ = _bin_op("&") + __ror__ = _bin_op("|") + __rxor__ = _bin_op("^") + # __rlshift__ = _reverse_op("<<") + # __rrshift__ = _reverse_op(">>") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + # __getattr__ = _bin_op("getField") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + upper = _unary_op("upper") + lower = _unary_op("lower") + + def substr(self, startPos, pos): + if type(startPos) != type(pos): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + + jc = self._jc.substr(startPos, pos) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, pos._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc, self._jdf, self.sql_ctx) + + __getslice__ = substr + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + isNull = _unary_op("isNull") + isNotNull = _unary_op("isNotNull") + + # `as` is keyword + def As(self, alias): + return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) + + def cast(self, dataType): + if self.sql_ctx is None: + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + else: + ssql_ctx = self.sql_ctx._ssql_ctx + jdt = ssql_ctx.parseDataType(dataType.json()) + return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) + + +def _aggregate_func(name): + """ Creat a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + # FIXME: can not access dsl.min/max ... 
+ jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol) + return Column(jc) + return staticmethod(_) + + +class Aggregator(object): + """ + A collections of builtin aggregators + """ + max = _aggregate_func("max") + min = _aggregate_func("min") + avg = mean = _aggregate_func("mean") + sum = _aggregate_func("sum") + first = _aggregate_func("first") + last = _aggregate_func("last") + count = _aggregate_func("count") def _test(): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7e..e8e207af462de 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -806,6 +806,9 @@ def tearDownClass(cls): def setUp(self): self.sqlCtx = SQLContext(self.sc) + self.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = self.sc.parallelize(self.testData) + self.df = self.sqlCtx.inferSchema(rdd) def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) @@ -821,7 +824,7 @@ def test_udf2(self): def test_udf_with_array_type(self): d = [Row(l=range(3), d={"key": range(5)})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.inferSchema(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() @@ -839,68 +842,51 @@ def test_broadcast_in_udf(self): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - srdd = self.sqlCtx.jsonRDD(rdd) - srdd.count() - srdd.collect() - srdd.schemaString() - srdd.schema() + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() # cache and checkpoint - self.assertFalse(srdd.is_cached) - srdd.persist() - srdd.unpersist() - srdd.cache() - self.assertTrue(srdd.is_cached) - self.assertFalse(srdd.isCheckpointed()) - self.assertEqual(None, srdd.getCheckpointFile()) - - srdd = srdd.coalesce(2, True) - srdd = srdd.repartition(3) - srdd = srdd.distinct() - srdd.intersection(srdd) - self.assertEqual(2, srdd.count()) - - srdd.registerTempTable("temp") - srdd = self.sqlCtx.sql("select foo from temp") - srdd.count() - srdd.collect() - - def test_distinct(self): - rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) - srdd = self.sqlCtx.jsonRDD(rdd) - self.assertEquals(srdd.getNumPartitions(), 10) - self.assertEquals(srdd.distinct().count(), 3) - result = srdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() def test_apply_schema_to_row(self): - srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) - self.assertEqual(srdd.collect(), srdd2.collect()) + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) - self.assertEqual(10, srdd3.count()) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) def 
test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - row = srdd.first() + df = self.sqlCtx.inferSchema(rdd) + row = df.head() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) self.assertEqual("2", row.d["key"].d) - l = srdd.map(lambda x: x.l).first() + l = df.map(lambda x: x.l).first() self.assertEqual(1, len(l)) self.assertEqual('s', l[0].b) - d = srdd.map(lambda x: x.d).first() + d = df.map(lambda x: x.d).first() self.assertEqual(1, len(d)) self.assertEqual(1.0, d["key"].c) - row = srdd.map(lambda x: x.d["key"]).first() + row = df.map(lambda x: x.d["key"]).first() self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) @@ -908,26 +894,26 @@ def test_infer_schema(self): d = [Row(l=[], d={}), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], srdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) - srdd.registerTempTable("test") + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) - srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(srdd.schema(), srdd2.schema()) - self.assertEqual({}, srdd2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) - srdd2.registerTempTable("test2") + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - k, v = srdd.first().m.items()[0] + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -935,9 +921,9 @@ def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - srdd.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").first() + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) @@ -945,12 +931,12 @@ def test_infer_schema_with_udt(self): from pyspark.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - schema = srdd.schema() + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) - srdd.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + df.registerTempTable("labeled_point") + point = 
self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) def test_apply_schema_with_udt(self): @@ -959,21 +945,52 @@ def test_apply_schema_with_udt(self): rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - srdd = self.sqlCtx.applySchema(rdd, schema) - point = srdd.first().point + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_parquet_with_udt(self): from pyspark.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd0 = self.sqlCtx.inferSchema(rdd) + df0 = self.sqlCtx.inferSchema(rdd) output_dir = os.path.join(self.tempdir.name, "labeled_point") - srdd0.saveAsParquetFile(output_dir) - srdd1 = self.sqlCtx.parquetFile(output_dir) - point = srdd1.first().point + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_column_operators(self): + from pyspark.sql import Column, LongType + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbit)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + # TODO(davies): fix aggregators + from pyspark.sql import Aggregator as Agg + # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 22941edef2d46..4c5fb3f45bf49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -47,7 +47,7 @@ object NewRelationInstances extends Rule[LogicalPlan] { .toSet plan transform { - case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance + case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance() } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 
3035d934ff9f8..f388cd5972bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -77,6 +77,9 @@ abstract class Attribute extends NamedExpression { * For example the SQL expression "1 + 1 AS a" could be represented as follows: * Alias(Add(Literal(1), Literal(1), "a")() * + * Note that exprId and qualifiers are in a separate parameter list because + * we only pattern match on child and name. + * * @param child the computation being performed * @param name the name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 613f4bb09daf5..5dc0539caec24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,9 +17,24 @@ package org.apache.spark.sql.catalyst.plans +object JoinType { + def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + case "inner" => Inner + case "outer" | "full" | "fullouter" => FullOuter + case "leftouter" | "left" => LeftOuter + case "rightouter" | "right" => RightOuter + case "leftsemi" => LeftSemi + } +} + sealed abstract class JoinType + case object Inner extends JoinType + case object LeftOuter extends JoinType + case object RightOuter extends JoinType + case object FullOuter extends JoinType + case object LeftSemi extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala index 19769986ef58c..d90af45b375e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { - def apply(output: Attribute*) = - new LocalRelation(output) + def apply(output: Attribute*): LocalRelation = new LocalRelation(output) + + def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation( + StructType(output1 +: output).toAttributes + ) } case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index e715d9434a2ab..bc22f688338b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -80,7 +80,7 @@ private[sql] trait CacheManager { * the in-memory columnar representation of the underlying table is expensive. 
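The new JoinType.apply lower-cases the string and strips underscores before matching, so "left_outer", "LEFT" and "leftouter" all resolve to LeftOuter, and an unrecognized string falls through the match and raises a MatchError. A rough Python rendering of that normalization, for illustration only (the real logic is the Scala pattern match above):

    def normalize_join_type(typ):
        t = typ.lower().replace("_", "")
        mapping = {
            "inner": "Inner",
            "outer": "FullOuter", "full": "FullOuter", "fullouter": "FullOuter",
            "leftouter": "LeftOuter", "left": "LeftOuter",
            "rightouter": "RightOuter", "right": "RightOuter",
            "leftsemi": "LeftSemi",
        }
        return mapping[t]  # KeyError here plays the role of the Scala MatchError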
*/ private[sql] def cacheQuery( - query: SchemaRDD, + query: DataFrame, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -100,7 +100,7 @@ private[sql] trait CacheManager { } /** Removes the data for the given SchemaRDD from the cache */ - private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock { + private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -110,7 +110,7 @@ private[sql] trait CacheManager { /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */ private[sql] def tryUncacheQuery( - query: SchemaRDD, + query: DataFrame, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,7 +123,7 @@ private[sql] trait CacheManager { } /** Optionally returns cached data for the given SchemaRDD */ - private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala new file mode 100644 index 0000000000000..7fc8347428df4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -0,0 +1,528 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.types._ + + +object Column { + def unapply(col: Column): Option[Expression] = Some(col.expr) + + def apply(colName: String): Column = new Column(colName) +} + + +/** + * A column in a [[DataFrame]]. + * + * `Column` instances can be created by: + * {{{ + * // 1. Select a column out of a DataFrame + * df("colName") + * + * // 2. Create a literal expression + * Literal(1) + * + * // 3. Create new columns from + * }}} + * + */ +// TODO: Improve documentation. 
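This name resolution is also what makes the Python wrapper's select() with no arguments behave like SELECT *: it falls back to the name "*", which the constructor above turns into a Star. A small sketch against a DataFrame df with key and value columns (illustrative names):

    df.select()                 # no arguments defaults to "*", i.e. every column
    df.select("key", "value")   # plain column names
    df["key"]                   # __getitem__ returns a Column
    df.key                      # attribute access returns the same Column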
+class Column( + sqlContext: Option[SQLContext], + plan: Option[LogicalPlan], + val expr: Expression) + extends DataFrame(sqlContext, plan) with ExpressionApi { + + /** Turn a Catalyst expression into a `Column`. */ + protected[sql] def this(expr: Expression) = this(None, None, expr) + + /** + * Create a new `Column` expression based on a column or attribute name. + * The resolution of this is the same as SQL. For example: + * + * - "colName" becomes an expression selecting the column named "colName". + * - "*" becomes an expression selecting all columns. + * - "df.*" becomes an expression selecting all columns in data frame "df". + */ + def this(name: String) = this(name match { + case "*" => Star(None) + case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2))) + case _ => UnresolvedAttribute(name) + }) + + override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined + + /** + * An implicit conversion function internal to this class. This function creates a new Column + * based on an expression. If the expression itself is not named, it aliases the expression + * by calling it "col". + */ + private[this] implicit def toColumn(expr: Expression): Column = { + val projectedPlan = plan.map { p => + Project(Seq(expr match { + case named: NamedExpression => named + case unnamed: Expression => Alias(unnamed, "col")() + }), p) + } + new Column(sqlContext, projectedPlan, expr) + } + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * df.select( -df("amount") ) + * }}} + */ + override def unary_- : Column = UnaryMinus(expr) + + /** + * Bitwise NOT. + * {{{ + * // Select the flags column and negate every bit. + * df.select( ~df("flags") ) + * }}} + */ + override def unary_~ : Column = BitwiseNot(expr) + + /** + * Invert a boolean expression, i.e. NOT. + * {{ + * // Select rows that are not active (isActive === false) + * df.select( !df("isActive") ) + * }} + */ + override def unary_! : Column = Not(expr) + + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA".equalTo(df("colB")) ) + * }}} + */ + override def === (other: Column): Column = EqualTo(expr, other.expr) + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA".equalTo("Zaharia") ) + * }}} + */ + override def === (literal: Any): Column = this === Literal.anyToLiteral(literal) + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA".equalTo(df("colB")) ) + * }}} + */ + override def equalTo(other: Column): Column = this === other + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA".equalTo("Zaharia") ) + * }}} + */ + override def equalTo(literal: Any): Column = this === literal + + /** + * Inequality test with an expression. + * {{{ + * // The following two both select rows in which colA does not equal colB. 
+ * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * }}} + */ + override def !== (other: Column): Column = Not(EqualTo(expr, other.expr)) + + /** + * Inequality test with a literal value. + * {{{ + * // The following two both select rows in which colA does not equal equal 15. + * df.select( df("colA") !== 15 ) + * df.select( !(df("colA") === 15) ) + * }}} + */ + override def !== (literal: Any): Column = this !== Literal.anyToLiteral(literal) + + /** + * Greater than an expression. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > Literal(21) ) + * }}} + */ + override def > (other: Column): Column = GreaterThan(expr, other.expr) + + /** + * Greater than a literal value. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > 21 ) + * }}} + */ + override def > (literal: Any): Column = this > Literal.anyToLiteral(literal) + + /** + * Less than an expression. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < Literal(21) ) + * }}} + */ + override def < (other: Column): Column = LessThan(expr, other.expr) + + /** + * Less than a literal value. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < 21 ) + * }}} + */ + override def < (literal: Any): Column = this < Literal.anyToLiteral(literal) + + /** + * Less than or equal to an expression. + * {{{ + * // The following selects people age 21 or younger than 21. + * people.select( people("age") <= Literal(21) ) + * }}} + */ + override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr) + + /** + * Less than or equal to a literal value. + * {{{ + * // The following selects people age 21 or younger than 21. + * people.select( people("age") <= 21 ) + * }}} + */ + override def <= (literal: Any): Column = this <= Literal.anyToLiteral(literal) + + /** + * Greater than or equal to an expression. + * {{{ + * // The following selects people age 21 or older than 21. + * people.select( people("age") >= Literal(21) ) + * }}} + */ + override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr) + + /** + * Greater than or equal to a literal value. + * {{{ + * // The following selects people age 21 or older than 21. + * people.select( people("age") >= 21 ) + * }}} + */ + override def >= (literal: Any): Column = this >= Literal.anyToLiteral(literal) + + /** + * Equality test with an expression that is safe for null values. + */ + override def <=> (other: Column): Column = EqualNullSafe(expr, other.expr) + + /** + * Equality test with a literal value that is safe for null values. + */ + override def <=> (literal: Any): Column = this <=> Literal.anyToLiteral(literal) + + /** + * True if the current expression is null. + */ + override def isNull: Column = IsNull(expr) + + /** + * True if the current expression is NOT null. + */ + override def isNotNull: Column = IsNotNull(expr) + + /** + * Boolean OR with an expression. + * {{{ + * // The following selects people that are in school or employed. + * people.select( people("inSchool") || people("isEmployed") ) + * }}} + */ + override def || (other: Column): Column = Or(expr, other.expr) + + /** + * Boolean OR with a literal value. + * {{{ + * // The following selects everything. + * people.select( people("inSchool") || true ) + * }}} + */ + override def || (literal: Boolean): Column = this || Literal.anyToLiteral(literal) + + /** + * Boolean AND with an expression. 
+ * {{{ + * // The following selects people that are in school and employed at the same time. + * people.select( people("inSchool") && people("isEmployed") ) + * }}} + */ + override def && (other: Column): Column = And(expr, other.expr) + + /** + * Boolean AND with a literal value. + * {{{ + * // The following selects people that are in school. + * people.select( people("inSchool") && true ) + * }}} + */ + override def && (literal: Boolean): Column = this && Literal.anyToLiteral(literal) + + /** + * Bitwise AND with an expression. + */ + override def & (other: Column): Column = BitwiseAnd(expr, other.expr) + + /** + * Bitwise AND with a literal value. + */ + override def & (literal: Any): Column = this & Literal.anyToLiteral(literal) + + /** + * Bitwise OR with an expression. + */ + override def | (other: Column): Column = BitwiseOr(expr, other.expr) + + /** + * Bitwise OR with a literal value. + */ + override def | (literal: Any): Column = this | Literal.anyToLiteral(literal) + + /** + * Bitwise XOR with an expression. + */ + override def ^ (other: Column): Column = BitwiseXor(expr, other.expr) + + /** + * Bitwise XOR with a literal value. + */ + override def ^ (literal: Any): Column = this ^ Literal.anyToLiteral(literal) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * }}} + */ + override def + (other: Column): Column = Add(expr, other.expr) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and 10. + * people.select( people("height") + 10 ) + * }}} + */ + override def + (literal: Any): Column = this + Literal.anyToLiteral(literal) + + /** + * Subtraction. Substract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people("height") - people("weight") ) + * }}} + */ + override def - (other: Column): Column = Subtract(expr, other.expr) + + /** + * Subtraction. Substract a literal value from this expression. + * {{{ + * // The following selects a person's height and substract it by 10. + * people.select( people("height") - 10 ) + * }}} + */ + override def - (literal: Any): Column = this - Literal.anyToLiteral(literal) + + /** + * Multiply this expression and another expression. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people("height") * people("weight") ) + * }}} + */ + override def * (other: Column): Column = Multiply(expr, other.expr) + + /** + * Multiply this expression and a literal value. + * {{{ + * // The following multiplies a person's height by 10. + * people.select( people("height") * 10 ) + * }}} + */ + override def * (literal: Any): Column = this * Literal.anyToLiteral(literal) + + /** + * Divide this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people("height") / people("weight") ) + * }}} + */ + override def / (other: Column): Column = Divide(expr, other.expr) + + /** + * Divide this expression by a literal value. + * {{{ + * // The following divides a person's height by 10. + * people.select( people("height") / 10 ) + * }}} + */ + override def / (literal: Any): Column = this / Literal.anyToLiteral(literal) + + /** + * Modulo (a.k.a. remainder) expression. 
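These Scala operators are what the Python Column wrappers shown earlier dispatch to (via $plus, $eq$eq$eq and friends). A short sketch of the expressions they allow, mirroring the new test_column_operators test; ci and cs are just illustrative column handles, and nothing is evaluated until the frame is collected:

    ci, cs = df.key, df.value
    (- ci - 1 - 2) % 3 * 2.5 / 3.5          # arithmetic composes and stays a Column
    rcc = (1 + ci), (1 - ci), (1 / ci)      # literals on the left use the reverse operators
    cb = (ci == 5), (ci != 0), (ci >= 0)    # comparisons map to ===, !==, >= on the Scala side
    css = cs.like('a'), cs.startswith('row'), cs.asc()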
+ */ + override def % (other: Column): Column = Remainder(expr, other.expr) + + /** + * Modulo (a.k.a. remainder) expression. + */ + override def % (literal: Any): Column = this % Literal.anyToLiteral(literal) + + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + */ + @scala.annotation.varargs + override def in(list: Column*): Column = In(expr, list.map(_.expr)) + + override def like(other: Column): Column = Like(expr, other.expr) + + override def like(literal: String): Column = this.like(Literal.anyToLiteral(literal)) + + override def rlike(other: Column): Column = RLike(expr, other.expr) + + override def rlike(literal: String): Column = this.rlike(Literal.anyToLiteral(literal)) + + + override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal)) + + override def getItem(ordinal: Column): Column = GetItem(expr, ordinal.expr) + + override def getField(fieldName: String): Column = GetField(expr, fieldName) + + + override def substr(startPos: Column, len: Column): Column = + Substring(expr, startPos.expr, len.expr) + + override def substr(startPos: Int, len: Int): Column = + this.substr(Literal.anyToLiteral(startPos), Literal.anyToLiteral(len)) + + override def contains(other: Column): Column = Contains(expr, other.expr) + + override def contains(literal: Any): Column = this.contains(Literal.anyToLiteral(literal)) + + + override def startsWith(other: Column): Column = StartsWith(expr, other.expr) + + override def startsWith(literal: String): Column = this.startsWith(Literal.anyToLiteral(literal)) + + override def endsWith(other: Column): Column = EndsWith(expr, other.expr) + + override def endsWith(literal: String): Column = this.endsWith(Literal.anyToLiteral(literal)) + + override def as(alias: String): Column = Alias(expr, alias)() + + override def cast(to: DataType): Column = Cast(expr, to) + + override def desc: Column = SortOrder(expr, Descending) + + override def asc: Column = SortOrder(expr, Ascending) +} + + +class ColumnName(name: String) extends Column(name) { + + /** Creates a new AttributeReference of type boolean */ + def boolean: StructField = StructField(name, BooleanType) + + /** Creates a new AttributeReference of type byte */ + def byte: StructField = StructField(name, ByteType) + + /** Creates a new AttributeReference of type short */ + def short: StructField = StructField(name, ShortType) + + /** Creates a new AttributeReference of type int */ + def int: StructField = StructField(name, IntegerType) + + /** Creates a new AttributeReference of type long */ + def long: StructField = StructField(name, LongType) + + /** Creates a new AttributeReference of type float */ + def float: StructField = StructField(name, FloatType) + + /** Creates a new AttributeReference of type double */ + def double: StructField = StructField(name, DoubleType) + + /** Creates a new AttributeReference of type string */ + def string: StructField = StructField(name, StringType) + + /** Creates a new AttributeReference of type date */ + def date: StructField = StructField(name, DateType) + + /** Creates a new AttributeReference of type decimal */ + def decimal: StructField = StructField(name, DecimalType.Unlimited) + + /** Creates a new AttributeReference of type decimal */ + def decimal(precision: Int, scale: Int): StructField = + StructField(name, DecimalType(precision, scale)) + + /** Creates a new AttributeReference of type timestamp */ + def timestamp: StructField = 
StructField(name, TimestampType) + + /** Creates a new AttributeReference of type binary */ + def binary: StructField = StructField(name, BinaryType) + + /** Creates a new AttributeReference of type array */ + def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType)) + + /** Creates a new AttributeReference of type map */ + def map(keyType: DataType, valueType: DataType): StructField = + map(MapType(keyType, valueType)) + + def map(mapType: MapType): StructField = StructField(name, mapType) + + /** Creates a new AttributeReference of type struct */ + def struct(fields: StructField*): StructField = struct(StructType(fields)) + + def struct(structType: StructType): StructField = StructField(name, structType) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala new file mode 100644 index 0000000000000..d0bb3640f8c1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -0,0 +1,596 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + +import java.util.{ArrayList, List => JList} + +import com.fasterxml.jackson.core.JsonFactory +import net.razorvine.pickle.Pickler + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil +import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.util.Utils + + +/** + * A collection of rows that have the same columns. + * + * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using + * various functions in [[SQLContext]]. + * {{{ + * val people = sqlContext.parquetFile("...") + * }}} + * + * Once created, it can be manipulated using the various domain-specific-language (DSL) functions + * defined in: [[DataFrame]] (this class), [[Column]], and [[dsl]] for Scala DSL. 
+ * + * To select a column from the data frame, use the apply method: + * {{{ + * val ageCol = people("age") // in Scala + * Column ageCol = people.apply("age") // in Java + * }}} + * + * Note that the [[Column]] type can also be manipulated through its various functions. + * {{ + * // The following creates a new column that increases everybody's age by 10. + * people("age") + 10 // in Scala + * }} + * + * A more concrete example: + * {{{ + * // To create DataFrame using SQLContext + * val people = sqlContext.parquetFile("...") + * val department = sqlContext.parquetFile("...") + * + * people.filter("age" > 30) + * .join(department, people("deptId") === department("id")) + * .groupBy(department("name"), "gender") + * .agg(avg(people("salary")), max(people("age"))) + * }}} + */ +// TODO: Improve documentation. +class DataFrame protected[sql]( + val sqlContext: SQLContext, + private val baseLogicalPlan: LogicalPlan, + operatorsEnabled: Boolean) + extends DataFrameSpecificApi with RDDApi[Row] { + + protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) = + this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined) + + protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true) + + @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + baseLogicalPlan + } + + /** + * An implicit conversion function internal to this class for us to avoid doing + * "new DataFrame(...)" everywhere. + */ + private[this] implicit def toDataFrame(logicalPlan: LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan, true) + } + + /** Return the list of numeric columns, useful for doing aggregation. */ + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + logicalPlan.resolve(n.name, sqlContext.analyzer.resolver).get + } + } + + /** Resolve a column name into a Catalyst [[NamedExpression]]. */ + protected[sql] def resolve(colName: String): NamedExpression = { + logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse( + throw new RuntimeException(s"""Cannot resolve column name "$colName"""")) + } + + /** Left here for compatibility reasons. */ + @deprecated("1.3.0", "use toDataFrame") + def toSchemaRDD: DataFrame = this + + /** + * Return the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala. + */ + def toDF: DataFrame = this + + /** Return the schema of this [[DataFrame]]. */ + override def schema: StructType = queryExecution.analyzed.schema + + /** Return all column names and their data types as an array. */ + override def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } + + /** Return all column names as an array. */ + override def columns: Array[String] = schema.fields.map(_.name) + + /** Print the schema to the console in a nice tree format. */ + override def printSchema(): Unit = println(schema.treeString) + + /** + * Cartesian join with another [[DataFrame]]. 
+ * + * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * + * @param right Right side of the join operation. + */ + override def join(right: DataFrame): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } + + /** + * Inner join with another [[DataFrame]], using the given join expression. + * + * {{{ + * // The following two are equivalent: + * df1.join(df2, $"df1Key" === $"df2Key") + * df1.join(df2).where($"df1Key" === $"df2Key") + * }}} + */ + override def join(right: DataFrame, joinExprs: Column): DataFrame = { + Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr)) + } + + /** + * Join with another [[DataFrame]], usin g the given join expression. The following performs + * a full outer join between `df1` and `df2`. + * + * {{{ + * df1.join(df2, "outer", $"df1Key" === $"df2Key") + * }}} + * + * @param right Right side of the join. + * @param joinExprs Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + */ + override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + } + + /** + * Return a new [[DataFrame]] sorted by the specified column, in ascending column. + * {{{ + * // The following 3 are equivalent + * df.sort("sortcol") + * df.sort($"sortcol") + * df.sort($"sortcol".asc) + * }}} + */ + override def sort(colName: String): DataFrame = { + Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan) + } + + /** + * Return a new [[DataFrame]] sorted by the given expressions. For example: + * {{{ + * df.sort($"col1", $"col2".desc) + * }}} + */ + @scala.annotation.varargs + override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = { + val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = true, logicalPlan) + } + + /** + * Return a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + */ + @scala.annotation.varargs + override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = { + sort(sortExpr, sortExprs :_*) + } + + /** + * Selecting a single column and return it as a [[Column]]. + */ + override def apply(colName: String): Column = { + val expr = resolve(colName) + new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr) + } + + /** + * Selecting a set of expressions, wrapped in a Product. + * {{{ + * // The following two are equivalent: + * df.apply(($"colA", $"colB" + 1)) + * df.select($"colA", $"colB" + 1) + * }}} + */ + override def apply(projection: Product): DataFrame = { + require(projection.productArity >= 1) + select(projection.productIterator.map { + case c: Column => c + case o: Any => new Column(Some(sqlContext), None, LiteralExpr(o)) + }.toSeq :_*) + } + + /** + * Alias the current [[DataFrame]]. + */ + override def as(name: String): DataFrame = Subquery(name, logicalPlan) + + /** + * Selecting a set of expressions. 
+ * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + */ + @scala.annotation.varargs + override def select(cols: Column*): DataFrame = { + val exprs = cols.zipWithIndex.map { + case (Column(expr: NamedExpression), _) => + expr + case (Column(expr: Expression), _) => + Alias(expr, expr.toString)() + } + Project(exprs.toSeq, logicalPlan) + } + + /** + * Selecting a set of columns. This is a variant of `select` that can only select + * existing columns using column names (i.e. cannot construct expressions). + * + * {{{ + * // The following two are equivalent: + * df.select("colA", "colB") + * df.select($"colA", $"colB") + * }}} + */ + @scala.annotation.varargs + override def select(col: String, cols: String*): DataFrame = { + select((col +: cols).map(new Column(_)) :_*) + } + + /** + * Filtering rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def filter(condition: Column): DataFrame = { + Filter(condition.expr, logicalPlan) + } + + /** + * Filtering rows using the given condition. This is an alias for `filter`. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def where(condition: Column): DataFrame = filter(condition) + + /** + * Filtering rows using the given condition. This is a shorthand meant for Scala. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def apply(condition: Column): DataFrame = filter(condition) + + /** + * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(cols: Column*): GroupedDataFrame = { + new GroupedDataFrame(this, cols.map(_.expr)) + } + + /** + * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy("department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): GroupedDataFrame = { + val colNames: Seq[String] = col1 +: cols + new GroupedDataFrame(this, colNames.map(colName => resolve(colName))) + } + + /** + * Aggregate on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) 
+ * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }} + */ + override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * Aggregate on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(max($"age"), avg($"salary")) + * df.groupBy().agg(max($"age"), avg($"salary")) + * }} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + + /** + * Return a new [[DataFrame]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. + */ + override def limit(n: Int): DataFrame = Limit(LiteralExpr(n), logicalPlan) + + /** + * Return a new [[DataFrame]] containing union of rows in this frame and another frame. + * This is equivalent to `UNION ALL` in SQL. + */ + override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + + /** + * Return a new [[DataFrame]] containing rows only in both this frame and another frame. + * This is equivalent to `INTERSECT` in SQL. + */ + override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + + /** + * Return a new [[DataFrame]] containing rows in this frame but not in another frame. + * This is equivalent to `EXCEPT` in SQL. + */ + override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + + /** + * Return a new [[DataFrame]] by sampling a fraction of rows. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * @param seed Seed for sampling. + */ + override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + Sample(fraction, withReplacement, seed, logicalPlan) + } + + /** + * Return a new [[DataFrame]] by sampling a fraction of rows, using a random seed. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + */ + override def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Return a new [[DataFrame]] by adding a column. + */ + override def addColumn(colName: String, col: Column): DataFrame = { + select(Column("*"), col.as(colName)) + } + + /** + * Return the first `n` rows. + */ + override def head(n: Int): Array[Row] = limit(n).collect() + + /** + * Return the first row. + */ + override def head(): Row = head(1).head + + /** + * Return the first row. Alias for head(). + */ + override def first(): Row = head() + + override def map[R: ClassTag](f: Row => R): RDD[R] = { + rdd.map(f) + } + + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + + /** + * Return the first `n` rows in the [[DataFrame]]. + */ + override def take(n: Int): Array[Row] = head(n) + + /** + * Return an array that contains all of [[Row]]s in this [[DataFrame]]. + */ + override def collect(): Array[Row] = rdd.collect() + + /** + * Return a Java list that contains all of [[Row]]s in this [[DataFrame]]. + */ + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + + /** + * Return the number of rows in the [[DataFrame]]. 
+ */ + override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) + + /** + * Return a new [[DataFrame]] that has exactly `numPartitions` partitions. + */ + override def repartition(numPartitions: Int): DataFrame = { + sqlContext.applySchema(rdd.repartition(numPartitions), schema) + } + + override def persist(): this.type = { + sqlContext.cacheQuery(this) + this + } + + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheQuery(this, None, newLevel) + this + } + + override def unpersist(blocking: Boolean): this.type = { + sqlContext.tryUncacheQuery(this, blocking) + this + } + + ///////////////////////////////////////////////////////////////////////////// + // I/O + ///////////////////////////////////////////////////////////////////////////// + + /** + * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s. + */ + override def rdd: RDD[Row] = { + val schema = this.schema + queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + } + + /** + * Registers this RDD as a temporary table using the given name. The lifetime of this temporary + * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * + * @group schema + */ + override def registerTempTable(tableName: String): Unit = { + sqlContext.registerRDDAsTable(this, tableName) + } + + /** + * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. + * Files that are written out using this method can be read back in as a [[DataFrame]] + * using the `parquetFile` function in [[SQLContext]]. + */ + override def saveAsParquetFile(path: String): Unit = { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame. This will fail if the table already + * exists. + * + * Note that this currently only works with DataFrame that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + override def saveAsTable(tableName: String): Unit = { + sqlContext.executePlan( + CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd + } + + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. + */ + @Experimental + override def insertInto(tableName: String, overwrite: Boolean): Unit = { + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite)).toRdd + } + + /** + * Return the content of the [[DataFrame]] as a RDD of JSON strings. + */ + override def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val jsonFactory = new JsonFactory() + iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) + } + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + /** + * A helpful function for Py4j, convert a list of Column to an array + */ + protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = { + cols.toList.toArray + } + + /** + * Converts a JavaRDD to a PythonRDD. 
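The same I/O surface is mirrored in the Python wrapper; a minimal round trip, close to what test_basic_functions and test_parquet_with_udt do (output_dir is an illustrative, writable path, not something defined by this patch):

    df.registerTempTable("people_tmp")            # "people_tmp" is an illustrative name
    sqlCtx.sql("select * from people_tmp").count()
    df.saveAsParquetFile(output_dir)
    df2 = sqlCtx.parquetFile(output_dir)
    df2.schema()                                  # schema preserved, per saveAsParquetFile above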
+ */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala new file mode 100644 index 0000000000000..1f1e9bd9899f6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate + + +/** + * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. + */ +class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) + extends GroupedDataFrameApi { + + private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { + val namedGroupingExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + new DataFrame(df.sqlContext, + Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) + } + + private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = { + df.numericColumns.map { c => + val a = f(c) + Alias(a, a.toString)() + } + } + + private[this] def strToExpr(expr: String): (Expression => Expression) = { + expr.toLowerCase match { + case "avg" | "average" | "mean" => Average + case "max" => Max + case "min" => Min + case "sum" => Sum + case "count" | "size" => Count + } + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max" + * "sum" -> "expense" + * )) + * }}} + */ + override def agg(exprs: Map[String, String]): DataFrame = { + exprs.map { case (colName, expr) => + val a = strToExpr(expr)(df(colName).expr) + Alias(a, a.toString)() + }.toSeq + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. 
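A sketch of the Map-based grouped aggregation defined above, where each key names a column and each value names one of the supported aggregates (avg, max, min, sum, count); the table and column names are illustrative:

    val df = sqlContext.table("employees")
    val perDept = df.groupBy("department").agg(Map(
      "age"    -> "max",
      "salary" -> "avg"
    ))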
+ * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max" + * "sum" -> "expense" + * )) + * }}} + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. + * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]]. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import org.apache.spark.sql.dsl._ + * df.groupBy("department").agg(max($"age"), sum($"expense")) + * }}} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = { + val aggExprs = (expr +: exprs).map(_.expr).map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + + new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) + } + + /** Count the number of rows for each group. */ + override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) + + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + */ + override def mean(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the max value for each numeric columns for each group. + */ + override def max(): DataFrame = aggregateNumericColumns(Max) + + /** + * Compute the mean value for each numeric columns for each group. + */ + override def avg(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the min value for each numeric column for each group. + */ + override def min(): DataFrame = aggregateNumericColumns(Min) + + /** + * Compute the sum for each numeric columns for each group. + */ + override def sum(): DataFrame = aggregateNumericColumns(Sum) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala new file mode 100644 index 0000000000000..08cd4d0f3f009 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.types._ + +object Literal { + + /** Return a new boolean literal. */ + def apply(literal: Boolean): Column = new Column(LiteralExpr(literal)) + + /** Return a new byte literal. */ + def apply(literal: Byte): Column = new Column(LiteralExpr(literal)) + + /** Return a new short literal. */ + def apply(literal: Short): Column = new Column(LiteralExpr(literal)) + + /** Return a new int literal. 
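A sketch of the Column-based agg and the per-group numeric shortcuts above, using the same illustrative table:

    import org.apache.spark.sql.dsl._

    val df = sqlContext.table("employees")
    val perDept  = df.groupBy($"department").agg(max($"age"), sum($"salary"))
    val averages = df.groupBy("department").avg()    // mean of every numeric column per group
    val counts   = df.groupBy("department").count()  // a single "count" column per group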
*/ + def apply(literal: Int): Column = new Column(LiteralExpr(literal)) + + /** Return a new long literal. */ + def apply(literal: Long): Column = new Column(LiteralExpr(literal)) + + /** Return a new float literal. */ + def apply(literal: Float): Column = new Column(LiteralExpr(literal)) + + /** Return a new double literal. */ + def apply(literal: Double): Column = new Column(LiteralExpr(literal)) + + /** Return a new string literal. */ + def apply(literal: String): Column = new Column(LiteralExpr(literal)) + + /** Return a new decimal literal. */ + def apply(literal: BigDecimal): Column = new Column(LiteralExpr(literal)) + + /** Return a new decimal literal. */ + def apply(literal: java.math.BigDecimal): Column = new Column(LiteralExpr(literal)) + + /** Return a new timestamp literal. */ + def apply(literal: java.sql.Timestamp): Column = new Column(LiteralExpr(literal)) + + /** Return a new date literal. */ + def apply(literal: java.sql.Date): Column = new Column(LiteralExpr(literal)) + + /** Return a new binary (byte array) literal. */ + def apply(literal: Array[Byte]): Column = new Column(LiteralExpr(literal)) + + /** Return a new null literal. */ + def apply(literal: Null): Column = new Column(LiteralExpr(null)) + + /** + * Return a Column expression representing the literal value. Throws an exception if the + * data type is not supported by SparkSQL. + */ + protected[sql] def anyToLiteral(literal: Any): Column = { + // If the literal is a symbol, convert it into a Column. + if (literal.isInstanceOf[Symbol]) { + return dsl.symbolToColumn(literal.asInstanceOf[Symbol]) + } + + val literalExpr = literal match { + case v: Int => LiteralExpr(v, IntegerType) + case v: Long => LiteralExpr(v, LongType) + case v: Double => LiteralExpr(v, DoubleType) + case v: Float => LiteralExpr(v, FloatType) + case v: Byte => LiteralExpr(v, ByteType) + case v: Short => LiteralExpr(v, ShortType) + case v: String => LiteralExpr(v, StringType) + case v: Boolean => LiteralExpr(v, BooleanType) + case v: BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited) + case v: java.math.BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited) + case v: Decimal => LiteralExpr(v, DecimalType.Unlimited) + case v: java.sql.Timestamp => LiteralExpr(v, TimestampType) + case v: java.sql.Date => LiteralExpr(v, DateType) + case v: Array[Byte] => LiteralExpr(v, BinaryType) + case null => LiteralExpr(null, NullType) + case _ => + throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal) + } + new Column(literalExpr) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0a22968cc7807..5030e689c36ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -30,7 +30,6 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -43,7 +42,7 @@ import org.apache.spark.util.Utils /** * :: AlphaComponent :: - * The entry point for running relational queries using Spark. 
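A sketch of the Literal factory above, lifting plain Scala values into Columns so they compose with other expressions; the DataFrame `df` and its columns are assumptions:

    import org.apache.spark.sql.dsl._

    val threshold: Column = Literal(18)
    val flagged = df.select($"name", ($"age" > threshold).as("isAdult"))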
Allows the creation of [[SchemaRDD]] + * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]] * objects and the execution of SQL queries. * * @groupname userf Spark SQL Functions @@ -53,7 +52,6 @@ import org.apache.spark.util.Utils class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging with CacheManager - with ExpressionConversions with Serializable { self => @@ -111,8 +109,8 @@ class SQLContext(@transient val sparkContext: SparkContext) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + + protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) sparkContext.getConf.getAll.foreach { case (key, value) if key.startsWith("spark.sql") => setConf(key, value) @@ -124,24 +122,24 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = { + implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { SparkPlan.currentContext.set(self) val attributeSeq = ScalaReflection.attributesFor[A] val schema = StructType.fromAttributes(attributeSeq) val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) } /** - * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]]. + * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. */ - def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { - new SchemaRDD(this, LogicalRelation(baseRelation)) + def baseRelationToSchemaRDD(baseRelation: BaseRelation): DataFrame = { + new DataFrame(this, LogicalRelation(baseRelation)) } /** * :: DeveloperApi :: - * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. * Example: @@ -170,11 +168,11 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ @DeveloperApi - def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) - new SchemaRDD(this, logicalPlan) + new DataFrame(this, logicalPlan) } /** @@ -183,7 +181,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. 
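A sketch of the implicit RDD-to-DataFrame conversion kept by this change, mirroring the case-class example used elsewhere in the docs; `sc` is an existing SparkContext:

    case class Record(key: Int, value: String)

    import sqlContext.createSchemaRDD
    val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
    val df: DataFrame = rdd                  // the implicit createSchemaRDD applies here
    df.registerTempTable("records")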
*/ - def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -201,7 +199,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ) : Row } } - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -210,35 +208,35 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { applySchema(rdd.rdd, beanClass) } /** - * Loads a Parquet file, returning the result as a [[SchemaRDD]]. + * Loads a Parquet file, returning the result as a [[DataFrame]]. * * @group userf */ - def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + def parquetFile(path: String): DataFrame = + new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** - * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) /** * :: Experimental :: * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. * * @group userf */ @Experimental - def jsonFile(path: String, schema: StructType): SchemaRDD = { + def jsonFile(path: String, schema: StructType): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, schema) } @@ -247,29 +245,29 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { + def jsonFile(path: String, samplingRatio: Double): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, samplingRatio) } /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[SchemaRDD]]. + * [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) /** * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. 
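A sketch of the JSON entry points above; the input path is an assumption:

    val events  = sqlContext.jsonFile("/data/events.json")        // full pass to infer the schema
    val sampled = sqlContext.jsonFile("/data/events.json", 0.5)   // infer the schema from a 50% sample
    val fromRdd = sqlContext.jsonRDD(sc.textFile("/data/events.json"))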
* * @group userf */ @Experimental - def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = Option(schema).getOrElse( @@ -283,7 +281,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = JsonRDD.nullTypeToStringType( @@ -298,8 +296,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), rdd.queryExecution.logical) + def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), rdd.logicalPlan) } /** @@ -321,17 +319,17 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def sql(sqlText: String): SchemaRDD = { + def sql(sqlText: String): DataFrame = { if (conf.dialect == "sql") { - new SchemaRDD(this, parseSql(sqlText)) + new DataFrame(this, parseSql(sqlText)) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}") } } /** Returns the specified table as a SchemaRDD */ - def table(tableName: String): SchemaRDD = - new SchemaRDD(this, catalog.lookupRelation(Seq(tableName))) + def table(tableName: String): DataFrame = + new DataFrame(this, catalog.lookupRelation(Seq(tableName))) /** * A collection of methods that are considered experimental, but can be used to hook into @@ -454,15 +452,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * access to the intermediate phases of query execution for developers. */ @DeveloperApi - protected abstract class QueryExecution { - def logical: LogicalPlan + protected class QueryExecution(val logical: LogicalPlan) { - lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) - lazy val withCachedData = useCachedData(analyzed) - lazy val optimizedPlan = optimizer(withCachedData) + lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical)) + lazy val withCachedData: LogicalPlan = useCachedData(analyzed) + lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... 
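A sketch of the sql() and table() entry points above, reusing the illustrative "records" table:

    val top = sqlContext.sql("SELECT key, value FROM records WHERE key > 10")
    top.collect().foreach(println)
    val sameTable: DataFrame = sqlContext.table("records")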
- lazy val sparkPlan = { + lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(self) planner(optimizedPlan).next() } @@ -512,7 +509,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schemaString: String): SchemaRDD = { + schemaString: String): DataFrame = { val schema = parseDataType(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) } @@ -522,7 +519,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schema: StructType): SchemaRDD = { + schema: StructType): DataFrame = { def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true @@ -549,7 +546,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala deleted file mode 100644 index d1e21dffeb8c5..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ /dev/null @@ -1,511 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.util.{List => JList} - -import scala.collection.JavaConversions._ - -import com.fasterxml.jackson.core.JsonFactory - -import net.razorvine.pickle.Pickler - -import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} -import org.apache.spark.annotation.{AlphaComponent, Experimental} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{BooleanType, StructType} -import org.apache.spark.storage.StorageLevel - -/** - * :: AlphaComponent :: - * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, - * SchemaRDDs can be used in relational queries, as shown in the examples below. - * - * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD - * whose elements are scala case classes into a SchemaRDD. 
This conversion can also be done - * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. - * - * A `SchemaRDD` can also be created by loading data in from external sources. - * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]] - * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. - * - * == SQL Queries == - * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once - * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements. - * - * {{{ - * // One method for defining the schema of an RDD is to make a case class with the desired column - * // names and types. - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - * // Any RDD containing case classes can be registered as a table. The schema of the table is - * // automatically inferred using scala reflection. - * rdd.registerTempTable("records") - * - * val results: SchemaRDD = sql("SELECT * FROM records") - * }}} - * - * == Language Integrated Queries == - * - * {{{ - * - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i))) - * - * // Example of language integrated queries. - * rdd.where('key === 1).orderBy('value.asc).select('key).collect() - * }}} - * - * @groupname Query Language Integrated Queries - * @groupdesc Query Functions that create new queries from SchemaRDDs. The - * result of all query functions is also a SchemaRDD, allowing multiple operations to be - * chained using a builder pattern. - * @groupprio Query -2 - * @groupname schema SchemaRDD Functions - * @groupprio schema -1 - * @groupname Ungrouped Base RDD Functions - */ -@AlphaComponent -class SchemaRDD( - @transient val sqlContext: SQLContext, - @transient val baseLogicalPlan: LogicalPlan) - extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { - - def baseSchemaRDD = this - - // ========================================================================================= - // RDD functions: Copy the internal row representation so we present immutable data to users. - // ========================================================================================= - - override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) - - override def getPartitions: Array[Partition] = firstParent[Row].partitions - - override protected def getDependencies: Seq[Dependency[_]] = { - schema // Force reification of the schema so it is available on executors. - - List(new OneToOneDependency(queryExecution.toRdd)) - } - - /** - * Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - lazy val schema: StructType = queryExecution.analyzed.schema - - /** - * Returns a new RDD with each row transformed to a JSON string. 
- * - * @group schema - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val jsonFactory = new JsonFactory() - iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) - } - } - - - // ======================================================================= - // Query DSL - // ======================================================================= - - /** - * Changes the output of this relation to the given expressions, similar to the `SELECT` clause - * in SQL. - * - * {{{ - * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName) - * }}} - * - * @param exprs a set of logical expression that will be evaluated for each input row. - * - * @group Query - */ - def select(exprs: Expression*): SchemaRDD = { - val aliases = exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) - } - - /** - * Filters the output, only returning those rows where `condition` evaluates to true. - * - * {{{ - * schemaRDD.where('a === 'b) - * schemaRDD.where('a === 1) - * schemaRDD.where('a + 'b > 10) - * }}} - * - * @group Query - */ - def where(condition: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Filter(condition, logicalPlan)) - - /** - * Performs a relational join on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be joined with this one. - * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.` - * @param on An optional condition for the join operation. This is equivalent to the `ON` - * clause in standard SQL. In the case of `Inner` joins, specifying a - * `condition` is equivalent to adding `where` clauses after the `join`. - * - * @group Query - */ - def join( - otherPlan: SchemaRDD, - joinType: JoinType = Inner, - on: Option[Expression] = None): SchemaRDD = - new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on)) - - /** - * Sorts the results by the given expressions. - * {{{ - * schemaRDD.orderBy('a) - * schemaRDD.orderBy('a, 'b) - * schemaRDD.orderBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def orderBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan)) - - /** - * Sorts the results by the given expressions within partition. - * {{{ - * schemaRDD.sortBy('a) - * schemaRDD.sortBy('a, 'b) - * schemaRDD.sortBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def sortBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan)) - - @deprecated("use limit with integer argument", "1.1.0") - def limit(limitExpr: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) - - /** - * Limits the results by the given integer. - * {{{ - * schemaRDD.limit(10) - * }}} - * @group Query - */ - def limit(limitNum: Int): SchemaRDD = - new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan)) - - /** - * Performs a grouping followed by an aggregation. - * - * {{{ - * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() - } - new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) - } - - /** - * Performs an aggregation over all Rows in this RDD. 
- * This is equivalent to a groupBy with no grouping expressions. - * - * {{{ - * schemaRDD.aggregate(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def aggregate(aggregateExprs: Expression*): SchemaRDD = { - groupBy()(aggregateExprs: _*) - } - - /** - * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes - * with the same name, for example, when performing self-joins. - * - * {{{ - * val x = schemaRDD.where('a === 1).as('x) - * val y = schemaRDD.where('a === 2).as('y) - * x.join(y).where("x.a".attr === "y.a".attr), - * }}} - * - * @group Query - */ - def as(alias: Symbol) = - new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) - - /** - * Combines the tuples of two RDDs with the same schema, keeping duplicates. - * - * @group Query - */ - def unionAll(otherPlan: SchemaRDD) = - new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational except on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be excepted from this one. - * - * @group Query - */ - def except(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational intersect on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be intersected with this one. - * - * @group Query - */ - def intersect(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan)) - - /** - * Filters tuples using a function over the value of the specified column. - * - * {{{ - * schemaRDD.where('a)((a: Int) => ...) - * }}} - * - * @group Query - */ - def where[T1](arg1: Symbol)(udf: (T1) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - - /** - * :: Experimental :: - * Returns a sampled version of the underlying dataset. - * - * @group Query - */ - @Experimental - override - def sample( - withReplacement: Boolean = true, - fraction: Double, - seed: Long) = - new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) - - /** - * :: Experimental :: - * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this - * implementation leverages the query optimizer to compute the count on the SchemaRDD, which - * supports features such as filter pushdown. - * - * @group Query - */ - @Experimental - override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0) - - /** - * :: Experimental :: - * Applies the given Generator, or table generating function, to this relation. - * - * @param generator A table generating function. The API for such functions is likely to change - * in future releases - * @param join when set to true, each output row of the generator is joined with the input row - * that produced it. - * @param outer when set to true, at least one row will be produced for each input row, similar to - * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a - * given row, a single row will be output, with `NULL` values for each of the - * generated columns. - * @param alias an optional alias that can be used as qualifier for the attributes that are - * produced by this generate operation. 
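For comparison, a sketch of how the expression-based calls being removed here translate to the new Column-based DataFrame API; the column names are illustrative and the mapping is approximate:

    import org.apache.spark.sql.dsl._

    val selected = df.select($"a", ($"b" + $"c").as("bc"))      // was: schemaRDD.select('a, 'b + 'c)
    val filtered = df.filter($"a" === 1)                        // was: schemaRDD.where('a === 1)
    val ordered  = df.sort($"a".asc, $"b".desc)                 // was: schemaRDD.orderBy('a.asc, 'b.desc)
    val grouped  = df.groupBy($"k").agg(sum($"v").as("total"))  // was: schemaRDD.groupBy('k)(Sum('v) as 'total)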
- * - * @group Query - */ - @Experimental - def generate( - generator: Generator, - join: Boolean = false, - outer: Boolean = false, - alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) - - /** - * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit - * conversion from a standard RDD to a SchemaRDD. - * - * @group schema - */ - def toSchemaRDD = this - - /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. - */ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same - * format as javaToPython. It is used by pyspark. - */ - private[sql] def collectToPython: JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(collect().map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same - * format as javaToPython and collectToPython. It is used by pyspark. - */ - private[sql] def takeSampleToPython( - withReplacement: Boolean, - num: Int, - seed: Long): JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value - * of base RDD functions that do not change schema. 
- * - * @param rdd RDD derived from this one and has same schema - * - * @group schema - */ - private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, - LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) - } - - // ======================================================================= - // Overridden RDD actions - // ======================================================================= - - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() - - def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect() : _*) - - override def take(num: Int): Array[Row] = limit(num).collect() - - // ======================================================================= - // Base RDD functions that do NOT change schema - // ======================================================================= - - // Transformations (return a new RDD) - - override def coalesce(numPartitions: Int, shuffle: Boolean = false) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.coalesce(numPartitions, shuffle)(ord)) - - override def distinct(): SchemaRDD = applySchema(super.distinct()) - - override def distinct(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.distinct(numPartitions)(ord)) - - def distinct(numPartitions: Int): SchemaRDD = - applySchema(super.distinct(numPartitions)(null)) - - override def filter(f: Row => Boolean): SchemaRDD = - applySchema(super.filter(f)) - - override def intersection(other: RDD[Row]): SchemaRDD = - applySchema(super.intersection(other)) - - override def intersection(other: RDD[Row], partitioner: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.intersection(other, partitioner)(ord)) - - override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.intersection(other, numPartitions)) - - override def repartition(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.repartition(numPartitions)(ord)) - - override def subtract(other: RDD[Row]): SchemaRDD = - applySchema(super.subtract(other)) - - override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.subtract(other, numPartitions)) - - override def subtract(other: RDD[Row], p: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.subtract(other, p)(ord)) - - /** Overridden cache function will always use the in-memory columnar caching. */ - override def cache(): this.type = { - sqlContext.cacheQuery(this) - this - } - - override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheQuery(this, None, newLevel) - this - } - - override def unpersist(blocking: Boolean): this.type = { - sqlContext.tryUncacheQuery(this, blocking) - this - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala deleted file mode 100644 index 3cf9209465b76..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. 
-* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.LogicalRDD - -/** - * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) - */ -private[sql] trait SchemaRDDLike { - @transient def sqlContext: SQLContext - @transient val baseLogicalPlan: LogicalPlan - - private[sql] def baseSchemaRDD: SchemaRDD - - /** - * :: DeveloperApi :: - * A lazily computed query execution workflow. All other RDD operations are passed - * through to the RDD that is produced by this workflow. This workflow is produced lazily because - * invoking the whole query optimization pipeline can be expensive. - * - * The query execution is considered a Developer API as phases may be added or removed in future - * releases. This execution is only exposed to provide an interface for inspecting the various - * phases for debugging purposes. Applications should not depend on particular phases existing - * or producing any specific output, even for exactly the same query. - * - * Additionally, the RDD exposed by this execution is not designed for consumption by end users. - * In particular, it does not contain any schema information, and it reuses Row objects - * internally. This object reuse improves performance, but can make programming against the RDD - * more difficult. Instead end users should perform RDD operations on a SchemaRDD directly. - */ - @transient - @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) - - @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - baseLogicalPlan - } - - override def toString = - s"""${super.toString} - |== Query Plan == - |${queryExecution.simpleString}""".stripMargin.trim - - /** - * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that - * are written out using this method can be read back in as a SchemaRDD using the `parquetFile` - * function. - * - * @group schema - */ - def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } - - /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. 
- * - * @group schema - */ - def registerTempTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) - } - - @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") - def registerAsTable(tableName: String): Unit = registerTempTable(tableName) - - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit = - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd - - /** - * :: Experimental :: - * Appends the rows from this RDD to the specified table. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * :: Experimental :: - * Creates a table from the the contents of this SchemaRDD. This will fail if the table already - * exists. - * - * Note that this currently only works with SchemaRDDs that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * @group schema - */ - @Experimental - def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd - - /** Returns the schema as a string in the tree format. - * - * @group schema - */ - def schemaString: String = baseSchemaRDD.schema.treeString - - /** Prints out the schema. - * - * @group schema - */ - def printSchema(): Unit = println(schemaString) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala new file mode 100644 index 0000000000000..073d41e938478 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala @@ -0,0 +1,289 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.storage.StorageLevel + + +/** + * An internal interface defining the RDD-like methods for [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. 
+ */ +trait RDDApi[T] { + + def cache(): this.type = persist() + + def persist(): this.type + + def persist(newLevel: StorageLevel): this.type + + def unpersist(): this.type = unpersist(blocking = false) + + def unpersist(blocking: Boolean): this.type + + def map[R: ClassTag](f: T => R): RDD[R] + + def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] + + def take(n: Int): Array[T] + + def collect(): Array[T] + + def collectAsList(): java.util.List[T] + + def count(): Long + + def first(): T + + def repartition(numPartitions: Int): DataFrame +} + + +/** + * An internal interface defining data frame related methods in [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +trait DataFrameSpecificApi { + + def schema: StructType + + def printSchema(): Unit + + def dtypes: Array[(String, String)] + + def columns: Array[String] + + def head(): Row + + def head(n: Int): Array[Row] + + ///////////////////////////////////////////////////////////////////////////// + // Relational operators + ///////////////////////////////////////////////////////////////////////////// + def apply(colName: String): Column + + def apply(projection: Product): DataFrame + + @scala.annotation.varargs + def select(cols: Column*): DataFrame + + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame + + def apply(condition: Column): DataFrame + + def as(name: String): DataFrame + + def filter(condition: Column): DataFrame + + def where(condition: Column): DataFrame + + @scala.annotation.varargs + def groupBy(cols: Column*): GroupedDataFrame + + @scala.annotation.varargs + def groupBy(col1: String, cols: String*): GroupedDataFrame + + def agg(exprs: Map[String, String]): DataFrame + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + def sort(colName: String): DataFrame + + @scala.annotation.varargs + def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame + + @scala.annotation.varargs + def sort(sortExpr: Column, sortExprs: Column*): DataFrame + + def join(right: DataFrame): DataFrame + + def join(right: DataFrame, joinExprs: Column): DataFrame + + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame + + def limit(n: Int): DataFrame + + def unionAll(other: DataFrame): DataFrame + + def intersect(other: DataFrame): DataFrame + + def except(other: DataFrame): DataFrame + + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame + + def sample(withReplacement: Boolean, fraction: Double): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // Column mutation + ///////////////////////////////////////////////////////////////////////////// + def addColumn(colName: String, col: Column): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // I/O and interaction with other frameworks + ///////////////////////////////////////////////////////////////////////////// + + def rdd: RDD[Row] + + def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + + def toJSON: RDD[String] + + def registerTempTable(tableName: String): Unit + + def saveAsParquetFile(path: String): Unit + + @Experimental + def saveAsTable(tableName: String): Unit + + @Experimental + def insertInto(tableName: String, overwrite: Boolean): Unit + + @Experimental + def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) + + ///////////////////////////////////////////////////////////////////////////// + // Stat functions + 
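A sketch of the caching contract declared in RDDApi above: cache() defaults to persist(), and unpersist() defaults to non-blocking; the table name is illustrative:

    val df = sqlContext.table("employees")
    df.cache()        // same as df.persist()
    df.count()        // materializes the cached data
    df.unpersist()    // same as df.unpersist(blocking = false)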
///////////////////////////////////////////////////////////////////////////// +// def describe(): Unit +// +// def mean(): Unit +// +// def max(): Unit +// +// def min(): Unit +} + + +/** + * An internal interface defining expression APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[Column]] directly, and do NOT use this. + */ +trait ExpressionApi { + + def isComputable: Boolean + + def unary_- : Column + def unary_! : Column + def unary_~ : Column + + def + (other: Column): Column + def + (other: Any): Column + def - (other: Column): Column + def - (other: Any): Column + def * (other: Column): Column + def * (other: Any): Column + def / (other: Column): Column + def / (other: Any): Column + def % (other: Column): Column + def % (other: Any): Column + def & (other: Column): Column + def & (other: Any): Column + def | (other: Column): Column + def | (other: Any): Column + def ^ (other: Column): Column + def ^ (other: Any): Column + + def && (other: Column): Column + def && (other: Boolean): Column + def || (other: Column): Column + def || (other: Boolean): Column + + def < (other: Column): Column + def < (other: Any): Column + def <= (other: Column): Column + def <= (other: Any): Column + def > (other: Column): Column + def > (other: Any): Column + def >= (other: Column): Column + def >= (other: Any): Column + def === (other: Column): Column + def === (other: Any): Column + def equalTo(other: Column): Column + def equalTo(other: Any): Column + def <=> (other: Column): Column + def <=> (other: Any): Column + def !== (other: Column): Column + def !== (other: Any): Column + + @scala.annotation.varargs + def in(list: Column*): Column + + def like(other: Column): Column + def like(other: String): Column + def rlike(other: Column): Column + def rlike(other: String): Column + + def contains(other: Column): Column + def contains(other: Any): Column + def startsWith(other: Column): Column + def startsWith(other: String): Column + def endsWith(other: Column): Column + def endsWith(other: String): Column + + def substr(startPos: Column, len: Column): Column + def substr(startPos: Int, len: Int): Column + + def isNull: Column + def isNotNull: Column + + def getItem(ordinal: Column): Column + def getItem(ordinal: Int): Column + def getField(fieldName: String): Column + + def cast(to: DataType): Column + + def asc: Column + def desc: Column + + def as(alias: String): Column +} + + +/** + * An internal interface defining aggregation APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this. + */ +trait GroupedDataFrameApi { + + def agg(exprs: Map[String, String]): DataFrame + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + def avg(): DataFrame + + def mean(): DataFrame + + def min(): DataFrame + + def max(): DataFrame + + def sum(): DataFrame + + def count(): DataFrame + + // TODO: Add var, std +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala new file mode 100644 index 0000000000000..29c3d26ae56d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala @@ -0,0 +1,495 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
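A sketch of the Column operators declared in ExpressionApi, combined through the DSL; the column names are assumptions:

    import org.apache.spark.sql.dsl._

    val adults  = df.filter($"age" >= 18 && $"name".isNotNull)
    val renamed = df.select(($"salary" * 1.1).as("adjusted"), upper($"name").as("NAME"))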
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Timestamp, Date} + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType + + +package object dsl { + + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** Converts $"col name" into an [[Column]]. */ + implicit class StringToColumn(val sc: StringContext) extends AnyVal { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args :_*)) + } + } + + private[this] implicit def toColumn(expr: Expression): Column = new Column(expr) + + def sum(e: Column): Column = Sum(e.expr) + def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def count(e: Column): Column = Count(e.expr) + + @scala.annotation.varargs + def countDistinct(expr: Column, exprs: Column*): Column = + CountDistinct((expr +: exprs).map(_.expr)) + + def avg(e: Column): Column = Average(e.expr) + def first(e: Column): Column = First(e.expr) + def last(e: Column): Column = Last(e.expr) + def min(e: Column): Column = Min(e.expr) + def max(e: Column): Column = Max(e.expr) + def upper(e: Column): Column = Upper(e.expr) + def lower(e: Column): Column = Lower(e.expr) + def sqrt(e: Column): Column = Sqrt(e.expr) + def abs(e: Column): Column = Abs(e.expr) + + // scalastyle:off + + object literals { + + implicit def booleanToLiteral(b: Boolean): Column = Literal(b) + + implicit def byteToLiteral(b: Byte): Column = Literal(b) + + implicit def shortToLiteral(s: Short): Column = Literal(s) + + implicit def intToLiteral(i: Int): Column = Literal(i) + + implicit def longToLiteral(l: Long): Column = Literal(l) + + implicit def floatToLiteral(f: Float): Column = Literal(f) + + implicit def doubleToLiteral(d: Double): Column = Literal(d) + + implicit def stringToLiteral(s: String): Column = Literal(s) + + implicit def dateToLiteral(d: Date): Column = Literal(d) + + implicit def bigDecimalToLiteral(d: BigDecimal): Column = Literal(d.underlying()) + + implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Column = Literal(d) + + implicit def timestampToLiteral(t: Timestamp): Column = Literal(t) + + implicit def binaryToLiteral(a: Array[Byte]): Column = Literal(a) + } + + + /* Use the following code to generate: + (0 to 22).map { x => + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
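A sketch of the two implicit entry points in the dsl package: Symbols and the $"..." interpolator both become Columns; the column names are illustrative:

    import org.apache.spark.sql.dsl._

    val bySymbol = df.select('name, 'age)         // Symbol -> ColumnName via symbolToColumn
    val byString = df.select($"name", $"age")     // StringContext -> ColumnName via $
    val cleaned  = df.select(lower($"name").as("name"), abs($"balance"))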
+ */ + def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf)) + }""") + } + + (0 to 22).map { x => + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val fTypes = Seq.fill(x + 1)("_").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, returnType, Seq($argsInUdf)) + }""") + } + } + */ + /** + * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag](f: Function0[RT]): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function0[_], returnType: DataType): Column = { + ScalaUdf(f, returnType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 52a31f01a4358..6fba76c52171b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical @@ -137,7 +137,9 @@ case class CacheTableCommand( isLazy: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName)) + plan.foreach { logicalPlan => + 
sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName) + } sqlContext.cacheTable(tableName) if (!isLazy) { @@ -159,7 +161,7 @@ case class CacheTableCommand( case class UncacheTableCommand(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - sqlContext.table(tableName).unpersist() + sqlContext.table(tableName).unpersist(blocking = false) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 4d7e338e8ed13..aeb0960e87f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -42,7 +42,7 @@ package object debug { * Augments SchemaRDDs with debug methods. */ @DeveloperApi - implicit class DebugQuery(query: SchemaRDD) { + implicit class DebugQuery(query: DataFrame) { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 6dd39be807037..7c49b5220d607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -37,5 +37,5 @@ package object sql { * Converts a logical plan into zero or more SparkPlans. 
*/ @DeveloperApi - type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 02ce1b3e6d811..0b312ef51daa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.util import org.apache.spark.util.Utils @@ -100,7 +100,7 @@ trait ParquetTest { */ protected def withParquetRDD[T <: Product: ClassTag: TypeTag] (data: Seq[T]) - (f: SchemaRDD => Unit): Unit = { + (f: DataFrame => Unit): Unit = { withParquetFile(data)(path => f(parquetFile(path))) } @@ -120,7 +120,7 @@ trait ParquetTest { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetRDD(data) { rdd => - rdd.registerTempTable(tableName) + sqlContext.registerRDDAsTable(rdd, tableName) withTempTable(tableName)(f) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 37853d4d03019..d13f2ce2a5e1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,19 +18,18 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => pruneFilterProjectRaw( l, @@ -112,23 +111,26 @@ private[sql] object DataSourceStrategy extends Strategy { } } + /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. 
*/ protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => + LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => + GreaterThanOrEqual(a.name, v) case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 171b816a26332..b4af91a768efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SchemaRDD, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.execution.RunnableCommand @@ -225,7 +225,8 @@ private [sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext) = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName) + sqlContext.registerRDDAsTable( + new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f9c082216085d..2564c849b87f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.test import 
scala.language.implicitConversions import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. */ @@ -40,8 +40,8 @@ object TestSQLContext * Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to * construct SchemaRDD directly out of local data without relying on implicits. */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = { - new SchemaRDD(this, plan) + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + new DataFrame(this, plan) } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java index 9ff40471a00af..e5588938ea162 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -61,7 +61,7 @@ public Integer call(String str) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); assert(result.getInt(0) == 4); } @@ -81,7 +81,7 @@ public Integer call(String str1, String str2) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); assert(result.getInt(0) == 9); } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 9e96738ac095a..badd00d34b9b1 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -98,8 +98,8 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema); - schemaRDD.registerTempTable("people"); + DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema); + df.registerTempTable("people"); Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); @@ -147,17 +147,17 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); - StructType actualSchema1 = schemaRDD1.schema(); + DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); + StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerTempTable("jsonTable1"); + df1.registerTempTable("jsonTable1"); List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); - StructType actualSchema2 = schemaRDD2.schema(); + DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); + StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD2.registerTempTable("jsonTable2"); + df2.registerTempTable("jsonTable2"); 
List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index cfc037caff2a9..34763156a6d11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index afbfe214f1ce4..a5848f219cea9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.types._ /* Implicits */ -import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import scala.language.postfixOps @@ -44,46 +42,46 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( - testData2.groupBy('a)('a, sum('b)), + testData2.groupBy("a").agg($"a", sum($"b")), Seq(Row(1,3), Row(2,3), Row(3,3)) ) checkAnswer( - testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), + testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), Row(9) ) checkAnswer( - testData2.aggregate(sum('b)), + testData2.agg(sum('b)), Row(9) ) } test("convert $\"attribute name\" into unresolved attribute") { checkAnswer( - testData.where($"key" === 1).select($"value"), + testData.where($"key" === Literal(1)).select($"value"), Row("1")) } test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === Literal(1)).select('value), Row("1")) } test("select *") { checkAnswer( - testData.select(Star(None)), + testData.select($"*"), testData.collect().toSeq) } test("simple select") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === Literal(1)).select('value), Row("1")) } test("select with functions") { checkAnswer( - testData.select(sum('value), avg('value), count(1)), + testData.select(sum('value), avg('value), count(Literal(1))), Row(5050.0, 50.5, 100)) checkAnswer( @@ -120,46 +118,19 @@ class DslQuerySuite extends QueryTest { checkAnswer( arrayData.orderBy('data.getItem(0).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(0).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) - } - - test("partition 
wide sorting") { - // 2 partitions totally, and - // Partition #1 with values: - // (1, 1) - // (1, 2) - // (2, 1) - // Partition #2 with values: - // (2, 2) - // (3, 1) - // (3, 2) - checkAnswer( - testData2.sortBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) - - checkAnswer( - testData2.sortBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.desc), - Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.asc), - Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("limit") { @@ -176,71 +147,51 @@ class DslQuerySuite extends QueryTest { mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } - test("SPARK-3395 limit distinct") { - val filtered = TestData.testData2 - .distinct() - .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending)) - .limit(1) - .registerTempTable("onerow") - checkAnswer( - sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: Nil) - } - - test("SPARK-3858 generator qualifiers are discarded") { - checkAnswer( - arrayData.as('ad) - .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) - .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 4).map(Row(_))) - } - test("average") { checkAnswer( - testData2.aggregate(avg('a)), + testData2.agg(avg('a)), Row(2.0)) checkAnswer( - testData2.aggregate(avg('a), sumDistinct('a)), // non-partial + testData2.agg(avg('a), sumDistinct('a)), // non-partial Row(2.0, 6.0) :: Nil) checkAnswer( - decimalData.aggregate(avg('a)), + decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial + decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2))), + decimalData.agg(avg('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( - testData3.aggregate(avg('b)), + testData3.agg(avg('b)), Row(2.0)) checkAnswer( - testData3.aggregate(avg('b), countDistinct('b)), + testData3.agg(avg('b), countDistinct('b)), Row(2.0, 1)) checkAnswer( - testData3.aggregate(avg('b), sumDistinct('b)), // non-partial + testData3.agg(avg('b), sumDistinct('b)), // non-partial Row(2.0, 2.0)) } test("zero average") { checkAnswer( - emptyTableData.aggregate(avg('a)), + emptyTableData.agg(avg('a)), Row(null)) checkAnswer( - emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial + emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial Row(null, null)) } @@ -248,28 +199,28 @@ class DslQuerySuite extends QueryTest { assert(testData2.count() === testData2.map(_ => 1).count()) checkAnswer( - testData2.aggregate(count('a), sumDistinct('a)), // non-partial + testData2.agg(count('a), sumDistinct('a)), // non-partial Row(6, 6.0)) } test("null count") { checkAnswer( - testData3.groupBy('a)('a, count('b)), + testData3.groupBy('a).agg('a, count('b)), 
Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.groupBy('a)('a, count('a + 'b)), + testData3.groupBy('a).agg('a, count('a + 'b)), Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), + testData3.agg(count('a), count('b), count(Literal(1)), countDistinct('a), countDistinct('b)), Row(2, 1, 2, 2, 1) ) checkAnswer( - testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial + testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial Row(1, 1, 2) ) } @@ -278,19 +229,19 @@ class DslQuerySuite extends QueryTest { assert(emptyTableData.count() === 0) checkAnswer( - emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial + emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } test("zero sum") { checkAnswer( - emptyTableData.aggregate(sum('a)), + emptyTableData.agg(sum('a)), Row(null)) } test("zero sum distinct") { checkAnswer( - emptyTableData.aggregate(sumDistinct('a)), + emptyTableData.agg(sumDistinct('a)), Row(null)) } @@ -320,7 +271,7 @@ class DslQuerySuite extends QueryTest { checkAnswer( // SELECT *, foo(key, value) FROM testData - testData.select(Star(None), foo.call('key, 'value)).limit(3), + testData.select($"*", callUDF(foo, 'key, 'value)).limit(3), Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } @@ -362,7 +313,7 @@ class DslQuerySuite extends QueryTest { test("upper") { checkAnswer( lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Row(c.toString.toUpperCase())) + ('a' to 'd').map(c => Row(c.toString.toUpperCase)) ) checkAnswer( @@ -379,7 +330,7 @@ class DslQuerySuite extends QueryTest { test("lower") { checkAnswer( upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Row(c.toString.toLowerCase())) + ('A' to 'F').map(c => Row(c.toString.toLowerCase)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index cd36da7751e83..79713725c0d77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,19 +20,20 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. 
TestData test("equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -105,17 +106,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("multiple-key equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, - Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } test("inner join where, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + upperCaseData.join(lowerCaseData).where('n === 'N), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("inner join ON, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -136,10 +136,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 1).as('y) + val x = testData2.where($"a" === Literal(1)).as("x") + val y = testData2.where($"a" === Literal(1)).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Row(1,1,1,1) :: Row(1,1,1,2) :: Row(1,2,1,1) :: @@ -148,22 +148,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, no matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 2).as('y) + val x = testData2.where($"a" === Literal(1)).as("x") + val y = testData2.where($"a" === Literal(2)).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Nil) } test("big inner join, 4 matches per row") { val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as('x) - val bigDataY = bigData.as('y) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") checkAnswer( - bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), - testData.flatMap( - row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) + bigDataX.join(bigDataY).where($"x.key" === $"y.key"), + testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { @@ -177,7 +176,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left outer join") { checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -186,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > Literal(1), "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ 
-195,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > Literal(1), "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -204,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -240,7 +239,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("right outer join") { checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -248,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > Literal(1), "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > Literal(1), "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -264,7 +263,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -299,14 +298,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") + upperCaseData.where('N <= Literal(4)).registerTempTable("left") + upperCaseData.where('N >= Literal(3)).registerTempTable("right") val left = UnresolvedRelation(Seq("left"), None) val right = UnresolvedRelation(Seq("right"), None) checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + left.join(right, $"left.N" === $"right.N", "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", 3, "C") :: @@ -315,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== Literal(3)), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -325,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== Literal(3)), "full"), Row(1, "A", null, null) :: 
Row(2, "B", null, null) :: Row(3, "C", null, null) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 42a21c148df53..07c52de377a60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -26,12 +26,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -44,10 +44,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -91,7 +91,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } @@ -102,7 +102,7 @@ class QueryTest extends PlanTest { } /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03b44ca1d6695..4fff99cb3f3e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -29,6 +30,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ + class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. 
TestData @@ -381,8 +383,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("big inner join, 4 matches per row") { - - checkAnswer( sql( """ @@ -396,7 +396,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData UNION ALL | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), - testData.flatMap( + testData.rdd.flatMap( row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } @@ -742,7 +742,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("metadata is propagated correctly") { - val person = sql("SELECT * FROM person") + val person: DataFrame = sql("SELECT * FROM person") val schema = person.schema val docKey = "doc" val docValue = "first name" @@ -751,14 +751,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person, schemaWithMeta) - def validateMetadata(rdd: SchemaRDD): Unit = { + val personWithMeta = applySchema(person.rdd, schemaWithMeta) + def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } personWithMeta.registerTempTable("personWithMeta") - validateMetadata(personWithMeta.select('name)) - validateMetadata(personWithMeta.select("name".attr)) - validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"id", $"name")) validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 808ed5288cfb8..fffa2b7dfa6e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.test._ /* Implicits */ @@ -29,11 +30,11 @@ case class TestData(key: Int, value: String) object TestData { val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD + (1 to 100).map(i => TestData(i, i.toString))).toDF testData.registerTempTable("testData") val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF negativeData.registerTempTable("negativeData") case class LargeAndSmallInts(a: Int, b: Int) @@ -44,7 +45,7 @@ object TestData { LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD + LargeAndSmallInts(3, 2) :: Nil).toDF largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) @@ -55,7 +56,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toSchemaRDD + TestData2(3, 2) :: Nil, 2).toDF testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) @@ -67,7 +68,7 @@ object TestData { DecimalData(2, 1) :: DecimalData(2, 2) :: 
DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toSchemaRDD + DecimalData(3, 2) :: Nil).toDF decimalData.registerTempTable("decimalData") case class BinaryData(a: Array[Byte], b: Int) @@ -77,17 +78,17 @@ object TestData { BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD + BinaryData("123".getBytes(), 4) :: Nil).toDF binaryData.registerTempTable("binaryData") case class TestData3(a: Int, b: Option[Int]) val testData3 = TestSQLContext.sparkContext.parallelize( TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toSchemaRDD + TestData3(2, Some(2)) :: Nil).toDF testData3.registerTempTable("testData3") - val emptyTableData = logical.LocalRelation('a.int, 'b.int) + val emptyTableData = logical.LocalRelation($"a".int, $"b".int) case class UpperCaseData(N: Int, L: String) val upperCaseData = @@ -97,7 +98,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toSchemaRDD + UpperCaseData(6, "F") :: Nil).toDF upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -106,7 +107,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toSchemaRDD + LowerCaseData(4, "d") :: Nil).toDF lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -200,6 +201,6 @@ object TestData { TestSQLContext.sparkContext.parallelize( ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toSchemaRDD + :: Nil).toDF complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 0c98120031242..5abd7b9383366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.dsl.StringToColumn import org.apache.spark.sql.test._ /* Implicits */ @@ -28,17 +29,17 @@ class UDFSuite extends QueryTest { test("Simple UDF") { udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { udf.register("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { @@ -46,7 +47,7 @@ class UDFSuite extends QueryTest { val result= sql("SELECT returnStruct('test', 'test2') as ret") - .select("ret.f1".attr).first().getString(0) - assert(result == "test") + .select($"ret.f1").head().getString(0) + assert(result === "test") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fbc8704f7837b..62b2e89403791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types._ + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -66,14 +68,14 @@ class UserDefinedTypeSuite extends QueryTest { test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[MyDenseVector] = - pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v } val featuresArrays: Array[MyDenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index e61f3c39631da..6f051dfe3d21d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 67007b8c093ca..be5e63c76f42e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -28,6 +29,7 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -40,7 +42,7 @@ class PlannerSuite extends FunSuite { } test("count is partially aggregated") { - val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed + val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } @@ -48,14 +50,14 @@ class PlannerSuite extends FunSuite { } test("count distinct is partially aggregated") { - val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed + val query = 
testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } test("mixed aggregates are partially aggregated") { val query = - testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed + testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } @@ -128,9 +130,9 @@ class PlannerSuite extends FunSuite { testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") - val a = testData.as('a) - val b = table("tiny").as('b) - val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + val a = testData.as("a") + val b = table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala deleted file mode 100644 index 272c0d4cb2335..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ - -/* Implicit conversions */ -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns - * from the input data. These will be replaced during analysis with specific AttributeReferences - * and then bound to specific ordinals during query planning. While TGFs could also access specific - * columns using hand-coded ordinals, doing so violates data independence. - * - * Note: this is only a rough example of how TGFs can be expressed, the final version will likely - * involve a lot more sugar for cleaner use in Scala/Java/etc. 
- */ -case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { - def children = input - protected def makeOutput() = 'nameAndAge.string :: Nil - - val Seq(nameAttr, ageAttr) = input - - override def eval(input: Row): TraversableOnce[Row] = { - val name = nameAttr.eval(input) - val age = ageAttr.eval(input).asInstanceOf[Int] - - Iterator( - new GenericRow(Array[Any](s"$name is $age years old")), - new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old"))) - } -} - -class TgfSuite extends QueryTest { - val inputData = - logical.LocalRelation('name.string, 'age.int).loadData( - ("michael", 29) :: Nil - ) - - test("simple tgf example") { - checkAnswer( - inputData.generate(ExampleTGF()), - Seq( - Row("michael is 29 years old"), - Row("Next year, michael will be 30 years old"))) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 94d14acccbb18..ef198f846c53a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.{Literal, QueryTest, Row, SQLConf} class JsonSuite extends QueryTest { import org.apache.spark.sql.json.TestJsonData._ @@ -463,8 +464,8 @@ class JsonSuite extends QueryTest { // in the Project. checkAnswer( jsonSchemaRDD. - where('num_str > BigDecimal("92233720368547758060")). - select('num_str + 1.2 as Symbol("num")), + where('num_str > Literal(BigDecimal("92233720368547758060"))). 
+ select(('num_str + Literal(1.2)).as("num")), Row(new java.math.BigDecimal("92233720368547758061.2")) ) @@ -820,7 +821,7 @@ class JsonSuite extends QueryTest { val schemaRDD1 = applySchema(rowRDD1, schema1) schemaRDD1.registerTempTable("applySchema1") - val schemaRDD2 = schemaRDD1.toSchemaRDD + val schemaRDD2 = schemaRDD1.toDF val result = schemaRDD2.toJSON.collect() assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") @@ -841,7 +842,7 @@ class JsonSuite extends QueryTest { val schemaRDD3 = applySchema(rowRDD2, schema2) schemaRDD3.registerTempTable("applySchema2") - val schemaRDD4 = schemaRDD3.toSchemaRDD + val schemaRDD4 = schemaRDD3.toDF val result2 = schemaRDD4.toJSON.collect() assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 1e7d3e06fc196..c9bc55900de98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -23,7 +23,7 @@ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. 
@@ -41,15 +41,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( - rdd: SchemaRDD, + rdd: DataFrame, predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Seq[Row]) => Unit, + checker: (DataFrame, Seq[Row]) => Unit, expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output: _*).where(predicate) + val query = rdd + .select(output.map(e => new org.apache.spark.sql.Column(e)): _*) + .where(new org.apache.spark.sql.Column(predicate)) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred @@ -71,13 +73,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { private def checkFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } private def checkFilterPredicate[T] (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -93,24 +95,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - integer") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -118,24 +120,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - long") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], 
(1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -143,24 +145,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - float") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -168,24 +170,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - double") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + 
checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -197,30 +199,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { + (implicit rdd: DataFrame): Unit = { + def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } @@ -231,7 +233,7 @@ class ParquetFilterSuite 
extends QueryTest with ParquetTest { def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -249,16 +251,16 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate( '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index a57e4e85a35ef..f03b3a32e34e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -32,12 +32,13 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) 
@@ -97,11 +98,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): SchemaRDD = + def makeDecimalRDD(decimal: DecimalType): DataFrame = sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) - .select('_1 cast decimal) + .select($"_1" cast decimal as "abcd") for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 7900b3e8948d9..a33cf1172cac9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import scala.language.existentials + import org.apache.spark.sql._ import org.apache.spark.sql.types._ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7385952861ee5..bb19ac232fcbe 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -23,6 +23,7 @@ import java.io._ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} + import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 166c56b9dfe20..ea9d61d8d0f5e 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation( sessionToActivePool: SMap[SessionHandle, String]) extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + 
result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index eaf7a1ddd4996..71e3954b2c7ac 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -30,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation( // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9d2cfd8e0d669..b746942cb1067 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -64,15 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) - override def sql(sqlText: String): SchemaRDD = { + override def sql(sqlText: String): DataFrame = { val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) + new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } @@ -352,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override protected[sql] val planner = hivePlanner /** Extends QueryExecution with hive specific features. */ - protected[sql] abstract class QueryExecution extends super.QueryExecution { + protected[sql] class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { /** * Returns the result as a hive compatible sequence of strings. 
For native commands, the diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 6952b126cf894..ace9329cd5821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} +import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate @@ -55,16 +55,15 @@ private[hive] trait HiveStrategies { */ @Experimental object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: SchemaRDD) { - def lowerCase = - new SchemaRDD(s.sqlContext, s.logicalPlan) + implicit class LogicalPlanHacks(s: DataFrame) { + def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan) def addPartitioningAttributes(attrs: Seq[Attribute]) = { // Don't add the partitioning key if its already present in the data. if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s } else { - new SchemaRDD( + new DataFrame( s.sqlContext, s.logicalPlan transform { case p: ParquetRelation => p.copy(partitioningAttributes = attrs) @@ -97,13 +96,13 @@ private[hive] trait HiveStrategies { // We are going to throw the predicates and projection back at the whole optimization // sequence so lets unresolve all the attributes, allowing them to be rebound to the // matching parquet attributes. - val unresolvedOtherPredicates = otherPredicates.map(_ transform { + val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true)) + }).reduceOption(And).getOrElse(Literal(true))) - val unresolvedProjection = projectList.map(_ transform { + val unresolvedProjection: Seq[Column] = projectList.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }) + }).map(new Column(_)) try { if (relation.hiveQlTable.isPartitioned) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 47431cef03e13..8e70ae8f56196 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -99,7 +99,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) /** Fewer partitions to speed up testing. 
*/ protected[sql] override lazy val conf: SQLConf = new SQLConf { @@ -150,8 +150,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r - protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution { - lazy val logical = HiveQl.parseSql(hql) + protected[hive] class HiveQLQueryExecution(hql: String) + extends this.QueryExecution(HiveQl.parseSql(hql)) { def hiveExec() = runSqlHive(hql) override def toString = hql + "\n" + super.toString } @@ -159,7 +159,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** * Override QueryExecution with special debug workflow. */ - abstract class QueryExecution extends super.QueryExecution { + class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { override lazy val analyzed = { val describedTables = logical match { case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index f320d732fb77a..ba391293884bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -36,12 +36,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -54,10 +54,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. 
@@ -101,7 +101,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f95a6b43af357..61e5117feab10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{QueryTest, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { @@ -28,7 +28,7 @@ class CachedTableSuite extends QueryTest { * Throws a test failed exception when the number of cached tables differs from the expected * number. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 0e6636d38ed3c..5775d83fcbf67 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq + testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq ) // Now overwrite. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index df72be7746ac6..d67b00bc9d08f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,11 +27,12 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) @@ -473,7 +474,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: SchemaRDD) = { + def isExplanation(result: DataFrame) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } @@ -842,7 +843,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = + def collectResults(rdd: DataFrame): Set[(String, String)] = rdd.collect().map { case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 16f77a438e1ae..a081227b4e6b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.Row +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.Row import org.apache.spark.util.Utils @@ -82,10 +83,10 @@ class HiveTableScanSuite extends HiveComparisonTest { sql("create table spark_4959 (col1 string)") sql("""insert into table spark_4959 select "hi" from src limit 1""") table("spark_4959").select( - 'col1.as('CaseSensitiveColName), - 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2") + 'col1.as("CaseSensitiveColName"), + 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") - assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi")) - assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi")) + assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f2374a215291b..dd0df1a9f6320 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -58,7 +58,7 @@ class HiveUdfSuite extends QueryTest { 
| getStruct(1).f3, | getStruct(1).f4, | getStruct(1).f5 FROM src LIMIT 1 - """.stripMargin).first() === Row(1, 2, 3, 4, 5)) + """.stripMargin).head() === Row(1, 2, 3, 4, 5)) } test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
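Note on the test-suite churn above: it follows one mechanical migration. SchemaRDD becomes DataFrame, the Symbol/.attr expression DSL becomes the Column DSL imported from org.apache.spark.sql.dsl, joins and aggregations move to the DataFrame methods, and callers that need RDD operations now go through .rdd. A minimal before/after sketch, assuming the same test fixtures and implicits these suites import (testData, testData2, TestSQLContext._); it is illustrative only, not part of the patch:

    import org.apache.spark.sql.dsl._                  // Column DSL: $"col", count, countDistinct, ...
    import org.apache.spark.sql.test.TestSQLContext._  // implicit conversions used by the test suites
    import org.apache.spark.sql.TestData._             // testData, testData2, ...

    // Old (pre-patch) style, as removed by the diffs:
    //   val join   = x.join(y, Inner, Some("x.a".attr === "y.a".attr))
    //   val counts = testData.groupBy('value)(Count('key))

    // New (post-patch) style, as added by the diffs:
    val x = testData2.as("x")
    val y = testData2.as("y")
    val join   = x.join(y, $"x.a" === $"y.a", "inner")        // join type passed as a string
    val counts = testData.groupBy('value).agg(count('key))    // aggregates built with agg(...)
    val keys   = join.select($"x.a").rdd                      // drop to RDD[Row] only when RDD ops are needed

The string join types ("inner", "left", "right", "full") replace the catalyst JoinType values (Inner, LeftOuter, RightOuter, FullOuter) that previously leaked into the test DSL.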
    Property Name                     Default   Meaning
    spark.driver.extraJavaOptions     (none)    A string of extra JVM options to pass to the driver.
                                                For instance, GC settings or other logging.
    spark.driver.extraClassPath       (none)    Extra classpath entries to append to the classpath of the driver.
    spark.driver.extraLibraryPath     (none)    Set a special library path to use when launching the driver JVM.
    spark.executor.extraJavaOptions   (none)
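For context, these properties are normally supplied through conf/spark-defaults.conf or spark-submit --conf rather than in application code, since the driver JVM is usually already running by the time a SparkConf is constructed. A small illustrative sketch; the GC flag and paths below are placeholders, not values taken from this patch set:

    import org.apache.spark.SparkConf

    // Illustrative only: real deployments usually put these in spark-defaults.conf
    // or pass them with `spark-submit --conf key=value`.
    val conf = new SparkConf()
      .setAppName("extra-jvm-options-example")
      .set("spark.executor.extraJavaOptions", "-XX:+PrintGCDetails")  // example GC-logging flag
      .set("spark.driver.extraClassPath", "/opt/extra/jars/*")        // hypothetical path
      .set("spark.driver.extraLibraryPath", "/opt/native/lib")        // hypothetical path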