diff --git a/assembly/pom.xml b/assembly/pom.xml
index b2a9d0780ee2b..594fa0c779e1b 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -142,8 +142,10 @@
<exclude>com/google/common/base/Absent*</exclude>
+ <exclude>com/google/common/base/Function</exclude>
<exclude>com/google/common/base/Optional*</exclude>
<exclude>com/google/common/base/Present*</exclude>
+ <exclude>com/google/common/base/Supplier</exclude>
diff --git a/bin/spark-class b/bin/spark-class
index 1b945461fabc8..2f0441bb3c1c2 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
. "$FWDIR"/bin/load-spark-env.sh
@@ -120,8 +121,8 @@ fi
JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
-if [ -e "$FWDIR/conf/java-opts" ] ; then
- JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`"
+if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then
+ JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`"
fi
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala!
diff --git a/build/mvn b/build/mvn
index 43471f83e904c..f91e2b4bdcc02 100755
--- a/build/mvn
+++ b/build/mvn
@@ -68,10 +68,10 @@ install_app() {
# Install maven under the build/ folder
install_mvn() {
install_app \
- "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \
- "apache-maven-3.2.3-bin.tar.gz" \
- "apache-maven-3.2.3/bin/mvn"
- MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn"
+ "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \
+ "apache-maven-3.2.5-bin.tar.gz" \
+ "apache-maven-3.2.5/bin/mvn"
+ MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn"
}
# Install zinc under the build/ folder
diff --git a/core/pom.xml b/core/pom.xml
index d9a49c9e08afc..1984682b9c099 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -372,8 +372,10 @@
com.google.guava:guava
<include>com/google/common/base/Absent*</include>
+ <include>com/google/common/base/Function</include>
<include>com/google/common/base/Optional*</include>
<include>com/google/common/base/Present*</include>
+ <include>com/google/common/base/Supplier</include>
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index a1f7133f897ee..f23ba9dba167f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -190,6 +190,7 @@ span.additional-metric-title {
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
-.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time {
+.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time,
+.getting_result_time {
display: none;
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index d4f2624061e35..419d093d55643 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -118,15 +118,17 @@ trait Logging {
// org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
// org.apache.logging.slf4j.Log4jLoggerFactory
val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
- val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4j12Initialized && usingLog4j12) {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
+ if (usingLog4j12) {
+ val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4j12Initialized) {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
}
}
Logging.initialized = true
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index a0ce107f43b16..cd91c8f87547b 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -17,8 +17,11 @@
package org.apache.spark
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, LinkedHashSet}
+import scala.collection.mutable.LinkedHashSet
+
import org.apache.spark.serializer.KryoSerializer
/**
@@ -46,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
- private[spark] val settings = new HashMap[String, String]()
+ private val settings = new ConcurrentHashMap[String, String]()
if (loadDefaults) {
// Load any spark.* system properties
for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) {
- settings(k) = v
+ set(k, v)
}
}
@@ -63,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
if (value == null) {
throw new NullPointerException("null value for " + key)
}
- settings(key) = value
+ settings.put(key, value)
this
}
@@ -129,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set multiple parameters together */
def setAll(settings: Traversable[(String, String)]) = {
- this.settings ++= settings
+ this.settings.putAll(settings.toMap.asJava)
this
}
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- if (!settings.contains(key)) {
- settings(key) = value
- }
+ settings.putIfAbsent(key, value)
this
}
@@ -163,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter; throws a NoSuchElementException if it's not set */
def get(key: String): String = {
- settings.getOrElse(key, throw new NoSuchElementException(key))
+ getOption(key).getOrElse(throw new NoSuchElementException(key))
}
/** Get a parameter, falling back to a default if not set */
def get(key: String, defaultValue: String): String = {
- settings.getOrElse(key, defaultValue)
+ getOption(key).getOrElse(defaultValue)
}
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- settings.get(key)
+ Option(settings.get(key))
}
/** Get all parameters as a list of pairs */
- def getAll: Array[(String, String)] = settings.clone().toArray
+ def getAll: Array[(String, String)] = {
+ settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray
+ }
/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
@@ -224,11 +227,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getAppId: String = get("spark.app.id")
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.contains(key)
+ def contains(key: String): Boolean = settings.containsKey(key)
/** Copy this object */
override def clone: SparkConf = {
- new SparkConf(false).setAll(settings)
+ new SparkConf(false).setAll(getAll)
}
/**
@@ -240,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
private[spark] def validateSettings() {
- if (settings.contains("spark.local.dir")) {
+ if (contains("spark.local.dir")) {
val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
"the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
logWarning(msg)
@@ -265,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
}
// Validate spark.executor.extraJavaOptions
- settings.get(executorOptsKey).map { javaOpts =>
+ getOption(executorOptsKey).map { javaOpts =>
if (javaOpts.contains("-Dspark")) {
val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " +
"Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
@@ -345,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
* configuration out for debugging.
*/
def toDebugString: String = {
- settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
+ getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
}
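The switch above from a mutable Scala HashMap to a ConcurrentHashMap makes a shared SparkConf safe to read and write from multiple threads. A minimal sketch of what that buys (illustrative only; the thread bodies and keys are made up):

    import org.apache.spark.SparkConf

    val conf = new SparkConf(false).set("spark.app.name", "demo")
    val threads = (1 to 4).map { i =>
      new Thread(new Runnable {
        override def run(): Unit = {
          conf.set(s"spark.demo.key$i", i.toString)  // set() now delegates to ConcurrentHashMap.put
          conf.getOption("spark.app.name")           // reads never observe a partially updated map
        }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())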
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6a354ed4d1486..4c4ee04cc515e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val startTime = System.currentTimeMillis()
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -525,6 +533,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
@@ -540,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -549,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -582,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -627,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -651,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
+ assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
@@ -684,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -703,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -782,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -802,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
@@ -817,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -828,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* If you plan to directly cache Hadoop writable objects, you should first copy them using
* a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -858,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -879,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -954,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+ logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -1046,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -1058,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -1075,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
env.blockManager.master.getStorageStatus
}
@@ -1084,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -1094,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}
@@ -1101,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
taskScheduler.schedulingMode
}
@@ -1206,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
postApplicationEnd()
ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
+ if (!stopped) {
+ stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
+ dagScheduler.stop()
+ dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -1289,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
@@ -1377,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1399,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1417,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1468,13 +1505,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
- /** Default min number of partitions for Hadoop RDDs when not given by user */
+ /**
+ * Default min number of partitions for Hadoop RDDs when not given by user
+ * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
+ * The reasons for this are discussed in https://github.com/mesos/spark/pull/718
+ */
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
private val nextShuffleId = new AtomicInteger(0)
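The new `stopped` flag and `assertNotStopped()` guard change the failure mode for code that keeps using a SparkContext after `stop()`: instead of a confusing NullPointerException later, guarded calls fail fast with an IllegalStateException. A short sketch (mirrors the tests added further down; a `local` master is assumed):

    import org.apache.spark.SparkContext

    val sc = new SparkContext("local", "stopped-context-demo")
    sc.stop()
    try {
      sc.parallelize(1 to 10)  // guarded by assertNotStopped()
    } catch {
      case e: IllegalStateException =>
        assert(e.getMessage.contains("stopped SparkContext"))
    }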
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4d418037bd33f..1264a8126153b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -326,6 +326,10 @@ object SparkEnv extends Logging {
// Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
+ // We need to set the executor ID before the MetricsSystem is created because sources and
+ // sinks specified in the metrics configuration file will want to incorporate this executor's
+ // ID into the metrics they report.
+ conf.set("spark.executor.id", executorId)
val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
ms.start()
ms
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 57f9faf5ddd1d..211e3ede53d9c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -133,10 +133,9 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
- private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
- : Option[() => Long] = {
+ private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = {
try {
- val threadStats = getFileSystemThreadStatistics(path, conf)
+ val threadStats = getFileSystemThreadStatistics()
val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesRead = f()
@@ -156,10 +155,9 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
- private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration)
- : Option[() => Long] = {
+ private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = {
try {
- val threadStats = getFileSystemThreadStatistics(path, conf)
+ val threadStats = getFileSystemThreadStatistics()
val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesWritten = f()
@@ -172,10 +170,8 @@ class SparkHadoopUtil extends Logging {
}
}
- private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = {
- val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
- val scheme = qualifiedPath.toUri().getScheme()
- val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ private def getFileSystemThreadStatistics(): Seq[AnyRef] = {
+ val stats = FileSystem.getAllStatistics()
stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
}
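With the path and Configuration parameters removed, the callbacks above aggregate FileSystem statistics across all schemes for the current thread. A hedged sketch of the contract (the method is private[spark]; the returned function reports bytes read on this thread since the callback was created):

    import org.apache.spark.deploy.SparkHadoopUtil

    val bytesReadCallback: Option[() => Long] =
      SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()

    bytesReadCallback match {
      case Some(bytesRead) =>
        // ... perform Hadoop reads on this thread ...
        val readSoFar: Long = bytesRead()  // delta since the callback was created
      case None =>
        // Thread-level FS statistics are unavailable (pre-Hadoop 2.5);
        // the RDD callers below fall back to split sizes instead.
    }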
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 2b084a2d73b78..0ae45f4ad9130 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -203,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
if (!logInfos.isEmpty) {
val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
def addIfAbsent(info: FsApplicationHistoryInfo) = {
- if (!newApps.contains(info.id)) {
+ if (!newApps.contains(info.id) ||
+ newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) &&
+ !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) {
newApps += (info.id -> info)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 9a4adfbbb3d71..823825302658c 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend(
}
case x: DisassociatedEvent =>
- logError(s"Driver $x disassociated! Shutting down.")
- System.exit(1)
+ if (x.remoteAddress == driver.anchorPath.address) {
+ logError(s"Driver $x disassociated! Shutting down.")
+ System.exit(1)
+ } else {
+ logWarning(s"Received irrelevant DisassociatedEvent $x")
+ }
case StopExecutor =>
logInfo("Driver commanded a shutdown")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 42566d1a14093..d8c2e41a7c715 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -41,11 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
*/
private[spark] class Executor(
executorId: String,
- slaveHostname: String,
+ executorHostname: String,
env: SparkEnv,
isLocal: Boolean = false)
extends Logging
{
+
+ logInfo(s"Starting executor ID $executorId on host $executorHostname")
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -58,12 +61,12 @@ private[spark] class Executor(
@volatile private var isStopped = false
// No ip or host:port - just hostname
- Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ Utils.checkHost(executorHostname, "Expected executed slave to be a hostname")
// must not have port specified.
- assert (0 == Utils.parseHostPort(slaveHostname)._2)
+ assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
- Utils.setCustomHostname(slaveHostname)
+ Utils.setCustomHostname(executorHostname)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index ddb5903bf6875..97912c68c5982 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -19,7 +19,6 @@ package org.apache.spark.executor
import java.util.concurrent.atomic.AtomicLong
-import org.apache.spark.executor.DataReadMethod
import org.apache.spark.executor.DataReadMethod.DataReadMethod
import scala.collection.mutable.ArrayBuffer
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 45633e3de01dd..83e8eb71260eb 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -130,8 +130,8 @@ private[spark] class MetricsSystem private (
if (appId.isDefined && executorId.isDefined) {
MetricRegistry.name(appId.get, executorId.get, source.sourceName)
} else {
- // Only Driver and Executor are set spark.app.id and spark.executor.id.
- // For instance, Master and Worker are not related to a specific application.
+ // Only Driver and Executor set spark.app.id and spark.executor.id.
+ // Other instance types, e.g. Master and Worker, are not related to a specific application.
val warningMsg = s"Using default name $defaultName for source because %s is not set."
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
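For illustration, the naming rule the comment above describes, using the Dropwizard helper already used in this file (the IDs below are made up): when both `spark.app.id` and `spark.executor.id` are set, the registry name is qualified with them; otherwise the plain default name is kept and a warning is logged.

    import com.codahale.metrics.MetricRegistry

    val qualified = MetricRegistry.name("app-20150101000000-0000", "1", "jvm")
    // qualified == "app-20150101000000-0000.1.jvm"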
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 056aef0bc210a..c3e3931042de2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.mapred.JobID
import org.apache.hadoop.mapred.TaskAttemptID
import org.apache.hadoop.mapred.TaskID
+import org.apache.hadoop.mapred.lib.CombineFileSplit
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
@@ -218,13 +219,13 @@ class HadoopRDD[K, V](
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = inputMetrics.bytesReadCallback.orElse(
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
split.inputSplit.value match {
- case split: FileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf)
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
case _ => None
}
- )
+ }
inputMetrics.setBytesReadCallback(bytesReadCallback)
var reader: RecordReader[K, V] = null
@@ -254,7 +255,8 @@ class HadoopRDD[K, V](
reader.close()
if (bytesReadCallback.isDefined) {
inputMetrics.updateBytesRead()
- } else if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+ } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
+ split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7b0e3c87ccff4..d86f95ac3e485 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -25,7 +25,7 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.input.WholeTextFileInputFormat
@@ -34,7 +34,7 @@ import org.apache.spark.Logging
import org.apache.spark.Partition
import org.apache.spark.SerializableWritable
import org.apache.spark.{SparkContext, TaskContext}
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
@@ -114,13 +114,13 @@ class NewHadoopRDD[K, V](
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = inputMetrics.bytesReadCallback.orElse(
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
split.serializableHadoopSplit.value match {
- case split: FileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf)
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
case _ => None
}
- )
+ }
inputMetrics.setBytesReadCallback(bytesReadCallback)
val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
@@ -163,7 +163,8 @@ class NewHadoopRDD[K, V](
reader.close()
if (bytesReadCallback.isDefined) {
inputMetrics.updateBytesRead()
- } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 0f37d830ef34f..49b88a90ab5af 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -990,7 +990,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
try {
@@ -1061,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
@@ -1086,11 +1086,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.commitJob()
}
- private def initHadoopOutputMetrics(context: TaskContext, config: Configuration)
- : (OutputMetrics, Option[() => Long]) = {
- val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir"))
- .map(new Path(_))
- .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config))
+ private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = {
+ val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
if (bytesWrittenCallback.isDefined) {
context.taskMetrics.outputMetrics = Some(outputMetrics)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 97012c7033f9f..ab7410a1f7f99 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
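A compact sketch of the two misuses that the element-type check and the new `sc` accessor above are meant to surface (mirrors the RDDSuite tests added below; `rdd` and `rdd2` are assumed to be existing RDD[Int]s):

    // Nested RDDs: logs the SPARK-5063 warning when the outer RDD is defined, and the
    // job fails because the captured inner RDD has a null SparkContext on executors.
    val nested = rdd.mapPartitions { iter => Seq(rdd2.map(identity)).iterator }

    // Action inside a transformation: each task hits the _sc == null branch above
    // and throws the SparkException explaining that this pattern is unsupported.
    rdd.map(x => x * rdd2.count()).collect()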
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 6f446c5a95a0a..4307029d44fbb 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,8 +24,10 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
- val TASK_DESERIALIZATION_TIME =
- """Time spent deserializating the task closure on the executor."""
+ val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor."
+
+ val SHUFFLE_READ_BLOCKED_TIME =
+ "Time that the task spent blocked waiting for shuffle data to be read from remote machines."
val INPUT = "Bytes read from Hadoop or from Spark storage."
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 09a936c2234c0..d8be1b20b3acd 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -132,6 +132,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
Task Deserialization Time
+ {if (hasShuffleRead) {
+
+
+
+ Shuffle Read Blocked Time
+
+
+ }}
@@ -167,7 +176,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
{if (hasInput) Seq(("Input", "")) else Nil} ++
{if (hasOutput) Seq(("Output", "")) else Nil} ++
- {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++
+ {if (hasShuffleRead) {
+ Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
+ ("Shuffle Read", ""))
+ } else {
+ Nil
+ }} ++
{if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++
{if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", ""))
else Nil} ++
@@ -271,6 +285,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
val outputQuantiles = <td>Output</td> +: getFormattedSizeQuantiles(outputSizes)
+ val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble
+ }
+ val shuffleReadBlockedQuantiles = <td>Shuffle Read Blocked Time</td> +:
+ getFormattedTimeQuantiles(shuffleReadBlockedTimes)
+
val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
}
@@ -308,7 +328,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
<td>{gettingResultQuantiles}</td>,
if (hasInput) <td>{inputQuantiles}</td> else Nil,
if (hasOutput) <td>{outputQuantiles}</td> else Nil,
- if (hasShuffleRead) <td>{shuffleReadQuantiles}</td> else Nil,
+ if (hasShuffleRead) {
+ <td class={TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME}>
+ {shuffleReadBlockedQuantiles}
+ </td>
+ <td>{shuffleReadQuantiles}</td>
+ } else {
+ Nil
+ },
if (hasShuffleWrite) <td>{shuffleWriteQuantiles}</td> else Nil,
if (hasBytesSpilled) <td>{memoryBytesSpilledQuantiles}</td> else Nil,
if (hasBytesSpilled) <td>{diskBytesSpilledQuantiles}</td> else Nil)
@@ -377,6 +404,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
.map(m => s"${Utils.bytesToString(m.bytesWritten)}")
.getOrElse("")
+ val maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime)
+ val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("")
+ val shuffleReadBlockedTimeReadable =
+ maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("")
+
val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead)
val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("")
@@ -449,6 +481,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}}
{if (hasShuffleRead) {
+ <td sorttable_customkey={shuffleReadBlockedTimeSortable} class={TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME}>
+ {shuffleReadBlockedTimeReadable}
+ </td>
{shuffleReadReadable}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 2d13bb6ddde42..37cf2c207ba40 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -27,6 +27,7 @@ package org.apache.spark.ui.jobs
private[spark] object TaskDetailsClassNames {
val SCHEDULER_DELAY = "scheduler_delay"
val TASK_DESERIALIZATION_TIME = "deserialization_time"
+ val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time"
val RESULT_SERIALIZATION_TIME = "serialization_time"
val GETTING_RESULT_TIME = "getting_result_time"
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 7584ae79fc920..21487bc24d58a 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
assert(jobB.get() === 100)
}
- ignore("two jobs sharing the same stage") {
+ test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
- // sem2: make sure the first stage is not finished until cancel is issued
+ // twoJobsSharingStageSemaphore:
+ // make sure the first stage is not finished until cancel is issued
val sem1 = new Semaphore(0)
- val sem2 = new Semaphore(0)
sc = new SparkContext("local[2]", "test")
sc.addSparkListener(new SparkListener {
@@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
// Create two actions that would share the some stages.
val rdd = sc.parallelize(1 to 10, 2).map { i =>
- sem2.acquire()
+ JobCancellationSuite.twoJobsSharingStageSemaphore.acquire()
(i, i)
}.reduceByKey(_+_)
val f1 = rdd.collectAsync()
@@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
future {
sem1.acquire()
f1.cancel()
- sem2.release(10)
+ JobCancellationSuite.twoJobsSharingStageSemaphore.release(10)
}
- // Expect both to fail now.
- // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2.
+ // Expect f1 to fail due to cancellation,
intercept[SparkException] { f1.get() }
- intercept[SparkException] { f2.get() }
+ // but f2 should not be affected
+ f2.get()
}
def testCount() {
@@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
object JobCancellationSuite {
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
+ val twoJobsSharingStageSemaphore = new Semaphore(0)
}
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index b0a70f012f1f3..af3272692d7a1 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
testPackage.runCallSiteTest(sc)
}
+ test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") {
+ sc = new SparkContext("local", "test")
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ sc.broadcast(Seq(1, 2, 3))
+ }
+ assert(thrown.getMessage.toLowerCase.contains("stopped"))
+ }
+
/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
package object testPackage extends Assertions {
def runCallSiteTest(sc: SparkContext) {
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val broadcast = sc.broadcast(rdd)
+ val broadcast = sc.broadcast(Array(1, 2, 3, 4))
broadcast.destroy()
val thrown = intercept[SparkException] { broadcast.value }
assert(thrown.getMessage.contains("BroadcastSuite.scala"))
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 8379883e065e7..3fbc1a21d10ed 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -167,6 +167,29 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers
list.size should be (1)
}
+ test("history file is renamed from inprogress to completed") {
+ val conf = new SparkConf()
+ .set("spark.history.fs.logDirectory", testDir.getAbsolutePath())
+ .set("spark.testing", "true")
+ val provider = new FsHistoryProvider(conf)
+
+ val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS)
+ writeFile(logFile1, true, None,
+ SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"),
+ SparkListenerApplicationEnd(2L)
+ )
+ provider.checkForLogs()
+ val appListBeforeRename = provider.getListing()
+ appListBeforeRename.size should be (1)
+ appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS)
+
+ logFile1.renameTo(new File(testDir, "app1"))
+ provider.checkForLogs()
+ val appListAfterRename = provider.getListing()
+ appListAfterRename.size should be (1)
+ appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS)
+ }
+
private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec],
events: SparkListenerEvent*) = {
val out =
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
index 1a28a9a187cd7..372d7aa453008 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
@@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
val conf = new MySparkConf()
@@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
val conf = new MySparkConf()
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 10a39990f80ce..81db66ae17464 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -26,7 +26,16 @@ import org.scalatest.FunSuite
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.{LongWritable, Text}
-import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
+import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf,
+ LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter,
+ TextInputFormat => OldTextInputFormat}
+import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat,
+ CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader}
+import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader,
+ TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat,
+ CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit,
+ FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat}
import org.apache.spark.SharedSparkContext
import org.apache.spark.deploy.SparkHadoopUtil
@@ -202,7 +211,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext {
val fs = FileSystem.getLocal(new Configuration())
val outPath = new Path(fs.getWorkingDirectory, "outdir")
- if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) {
+ if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) {
val taskBytesWritten = new ArrayBuffer[Long]()
sc.addSparkListener(new SparkListener() {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
@@ -225,4 +234,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext {
}
}
}
+
+ test("input metrics with old CombineFileInputFormat") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable],
+ classOf[Text], 2).count()
+ }
+ assert(bytesRead >= tmpFile.length())
+ }
+
+ test("input metrics with new CombineFileInputFormat") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable],
+ classOf[Text], new Configuration()).count()
+ }
+ assert(bytesRead >= tmpFile.length())
+ }
+}
+
+/**
+ * Hadoop 2 has a version of this, but we can't use it for backwards compatibility
+ */
+class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] {
+ override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter)
+ : OldRecordReader[LongWritable, Text] = {
+ new OldCombineFileRecordReader[LongWritable, Text](conf,
+ split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper]
+ .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]])
+ }
+}
+
+class OldCombineTextRecordReaderWrapper(
+ split: OldCombineFileSplit,
+ conf: Configuration,
+ reporter: Reporter,
+ idx: Integer) extends OldRecordReader[LongWritable, Text] {
+
+ val fileSplit = new OldFileSplit(split.getPath(idx),
+ split.getOffset(idx),
+ split.getLength(idx),
+ split.getLocations())
+
+ val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit,
+ conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader]
+
+ override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value)
+ override def createKey(): LongWritable = delegate.createKey()
+ override def createValue(): Text = delegate.createValue()
+ override def getPos(): Long = delegate.getPos
+ override def close(): Unit = delegate.close()
+ override def getProgress(): Float = delegate.getProgress
+}
+
+/**
+ * Hadoop 2 has a version of this, but we can't use it for backwards compatibility
+ */
+class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] {
+ def createRecordReader(split: NewInputSplit, context: TaskAttemptContext)
+ : NewRecordReader[LongWritable, Text] = {
+ new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit],
+ context, classOf[NewCombineTextRecordReaderWrapper])
+ }
}
+
+class NewCombineTextRecordReaderWrapper(
+ split: NewCombineFileSplit,
+ context: TaskAttemptContext,
+ idx: Integer) extends NewRecordReader[LongWritable, Text] {
+
+ val fileSplit = new NewFileSplit(split.getPath(idx),
+ split.getOffset(idx),
+ split.getLength(idx),
+ split.getLocations())
+
+ val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context)
+
+ override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = {
+ delegate.initialize(fileSplit, context)
+ }
+
+ override def nextKeyValue(): Boolean = delegate.nextKeyValue()
+ override def getCurrentKey(): LongWritable = delegate.getCurrentKey
+ override def getCurrentValue(): Text = delegate.getCurrentValue
+ override def getProgress(): Float = delegate.getProgress
+ override def close(): Unit = delegate.close()
+}
\ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 381ee2d45630f..e33b4bbbb8e4c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -927,4 +927,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
mutableDependencies += dep
}
}
+
+ test("nested RDDs are not supported (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator }
+ nestedRDD.count()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("actions cannot be performed inside of transformations (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ rdd.map(x => x * rdd2.count).collect()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
+ val existingRDD = sc.parallelize(1 to 100)
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ existingRDD.count()
+ }
+ assert(thrown.getMessage.contains("shutdown"))
+ }
+
+ test("cannot call methods on a stopped SparkContext (SPARK-5063)") {
+ sc.stop()
+ def assertFails(block: => Any): Unit = {
+ val thrown = intercept[IllegalStateException] {
+ block
+ }
+ assert(thrown.getMessage.contains("stopped"))
+ }
+ assertFails { sc.parallelize(1 to 100) }
+ assertFails { sc.textFile("/nonexistent-path") }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
index dae7bf0e336de..8cf951adb354b 100644
--- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
@@ -49,7 +49,7 @@ class LocalDirsSuite extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
// spark.local.dir only contains invalid directories, but that's not a problem since
diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
index 10541f878476c..1026cb2aa7cae 100644
--- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
@@ -41,7 +41,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
eventLoop.start()
(1 to 100).foreach(eventLoop.post)
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert((1 to 100) === buffer.toSeq)
}
eventLoop.stop()
@@ -76,7 +76,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
eventLoop.start()
eventLoop.post(1)
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert(e === receivedError)
}
eventLoop.stop()
@@ -98,7 +98,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
eventLoop.start()
eventLoop.post(1)
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert(e === receivedError)
assert(eventLoop.isActive)
}
@@ -153,7 +153,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}.start()
}
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert(threadNum * eventsFromEachThread === receivedEventsCount)
}
eventLoop.stop()
@@ -185,4 +185,22 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
assert(false === eventLoop.isActive)
}
+
+ test("EventLoop: stop in eventThread") {
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ stop()
+ }
+
+ override def onError(e: Throwable): Unit = {
+ }
+
+ }
+ eventLoop.start()
+ eventLoop.post(1)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(!eventLoop.isActive)
+ }
+ }
}
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index b1b8cb44e098b..b2a7e092a0291 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -122,8 +122,14 @@ if [[ ! "$@" =~ --package-only ]]; then
for file in $(find . -type f)
do
echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file;
- gpg --print-md MD5 $file > $file.md5;
- gpg --print-md SHA1 $file > $file.sha1
+ if [ $(command -v md5) ]; then
+ # Available on OS X; -q to keep only hash
+ md5 -q $file > $file.md5
+ else
+ # Available on Linux; cut to keep only hash
+ md5sum $file | cut -f1 -d' ' > $file.md5
+ fi
+ shasum -a 1 $file | cut -f1 -d' ' > $file.sha1
done
nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id
diff --git a/docs/configuration.md b/docs/configuration.md
index efbab4085317a..7c5b6d011cfd3 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -197,6 +197,27 @@ Apart from these, the following properties are also available, and may be useful
#### Runtime Environment
Property Name Default Meaning
+<tr>
+ <td><code>spark.driver.extraJavaOptions</code></td>
+ <td>(none)</td>
+ <td>
+ A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
+ </td>
+</tr>
+<tr>
+ <td><code>spark.driver.extraClassPath</code></td>
+ <td>(none)</td>
+ <td>
+ Extra classpath entries to append to the classpath of the driver.
+ </td>
+</tr>
+<tr>
+ <td><code>spark.driver.extraLibraryPath</code></td>
+ <td>(none)</td>
+ <td>
+ Set a special library path to use when launching the driver JVM.
+ </td>
+</tr>
spark.executor.extraJavaOptions
(none)
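
These three driver-side properties mirror the existing spark.executor.* ones. As a rough illustration (not taken from the patch), they are normally supplied before the driver JVM starts, e.g. in spark-defaults.conf or via spark-submit --conf; setting them on a SparkConf only takes effect when the driver has not been launched yet, such as in cluster mode. The values below are placeholders.

    import org.apache.spark.SparkConf

    // Placeholder values; in client mode prefer spark-defaults.conf or
    // `spark-submit --conf`, since the driver JVM is already running
    // by the time this code executes.
    val conf = new SparkConf()
      .set("spark.driver.extraJavaOptions", "-XX:+PrintGCDetails")
      .set("spark.driver.extraClassPath", "/opt/libs/extra.jar")
      .set("spark.driver.extraLibraryPath", "/opt/native/lib")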
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index 2094963392295..ef18cec9371d6 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -192,12 +192,11 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva
recommendation by measuring the Mean Squared Error of rating prediction.
{% highlight python %}
-from pyspark.mllib.recommendation import ALS
-from numpy import array
+from pyspark.mllib.recommendation import ALS, Rating
# Load and parse the data
data = sc.textFile("data/mllib/als/test.data")
-ratings = data.map(lambda line: array([float(x) for x in line.split(',')]))
+ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2])))
# Build the recommendation model using Alternating Least Squares
rank = 10
@@ -205,10 +204,10 @@ numIterations = 20
model = ALS.train(ratings, rank, numIterations)
# Evaluate the model on training data
-testdata = ratings.map(lambda p: (int(p[0]), int(p[1])))
+testdata = ratings.map(lambda p: (p[0], p[1]))
predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
-MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count()
+MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
print("Mean Squared Error = " + str(MSE))
{% endhighlight %}
@@ -217,7 +216,7 @@ signals), you can use the trainImplicit method to get better results.
{% highlight python %}
# Build the recommendation model using Alternating Least Squares based on implicit ratings
-model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01)
+model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01)
{% endhighlight %}
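
For comparison, a Scala sketch of the same evaluation flow against the public org.apache.spark.mllib.recommendation API (assuming an existing SparkContext `sc` and the same sample data file):

    import org.apache.spark.mllib.recommendation.{ALS, Rating}

    val data = sc.textFile("data/mllib/als/test.data")
    val ratings = data.map { line =>
      val Array(user, item, rate) = line.split(',')
      Rating(user.toInt, item.toInt, rate.toDouble)
    }

    val rank = 10
    val numIterations = 20
    val model = ALS.train(ratings, rank, numIterations)

    // Evaluate on the training data with mean squared error.
    val usersProducts = ratings.map { case Rating(user, product, _) => (user, product) }
    val predictions = model.predict(usersProducts).map {
      case Rating(user, product, rate) => ((user, product), rate)
    }
    val ratesAndPreds = ratings.map {
      case Rating(user, product, rate) => ((user, product), rate)
    }.join(predictions)
    val mse = ratesAndPreds.map { case (_, (r, p)) => (r - p) * (r - p) }.mean()
    println(s"Mean Squared Error = $mse")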
diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md
index 0e38fe2144e9f..77c0abbbacbd0 100644
--- a/docs/streaming-kafka-integration.md
+++ b/docs/streaming-kafka-integration.md
@@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide
streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]);
See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 3bd1deaccfafe..14a87f8436984 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell).
Alternatively, if your application is submitted from a machine far from the worker machines (e.g.
locally on your laptop), it is common to use `cluster` mode to minimize network latency between
-the drivers and the executors. Note that `cluster` mode is currently not supported for standalone
-clusters, Mesos clusters, or Python applications.
+the drivers and the executors. Note that `cluster` mode is currently not supported for
+Mesos clusters or Python applications.
For Python applications, simply pass a `.py` file in the place of `<application-jar>` instead of a JAR,
and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`.
diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
index 2adc63f7ff30e..387c0e421334b 100644
--- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
+++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
@@ -76,7 +76,7 @@ object KafkaWordCountProducer {
val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
- // Zookeper connection properties
+ // Zookeeper connection properties
val props = new Properties()
props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 247d2a5e31a8c..0fbee6e433608 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -33,7 +33,7 @@
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.Row;
@@ -71,7 +71,7 @@ public static void main(String[] args) {
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
- SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -112,11 +112,11 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- cvModel.transform(test).registerAsTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ cvModel.transform(test).registerTempTable("prediction");
+ DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
for (Row r: predictions.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ ", prediction=" + r.get(3));
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 5b92655e2e838..eaaa344be49c8 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -28,7 +28,7 @@
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.Row;
@@ -48,13 +48,13 @@ public static void main(String[] args) {
// Prepare training data.
// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans
- // into SchemaRDDs, where it uses the bean metadata to infer the schema.
+ // into DataFrames, where it uses the bean metadata to infer the schema.
List<LabeledPoint> localTraining = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -94,14 +94,14 @@ public static void main(String[] args) {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
// column since we renamed the lr.scoreCol parameter previously.
- model2.transform(test).registerAsTable("results");
- SchemaRDD results =
+ model2.transform(test).registerTempTable("results");
+ DataFrame results =
jsql.sql("SELECT features, label, probability, prediction FROM results");
for (Row r: results.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 74db449fada7d..82d665a3e1386 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -29,7 +29,7 @@
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.Row;
@@ -54,7 +54,7 @@ public static void main(String[] args) {
new LabeledDocument(1L, "b d", 0.0),
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
- SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -79,11 +79,11 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
- model.transform(test).registerAsTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ model.transform(test).registerTempTable("prediction");
+ DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
for (Row r: predictions.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ ", prediction=" + r.get(3));
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index b70804635d5c9..8defb769ffaaf 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -26,9 +26,9 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
public class JavaSparkSQL {
public static class Person implements Serializable {
@@ -74,13 +74,13 @@ public Person call(String line) {
});
// Apply a schema to an RDD of Java Beans and register it as a table.
- SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class);
+ DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
- SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
- // The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+ // The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
List<String> teenagerNames = teenagers.toJavaRDD().map(new Function<Row, String>() {
@Override
@@ -93,17 +93,17 @@ public String call(Row row) {
}
System.out.println("=== Data source: Parquet File ===");
- // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information.
+ // DataFrames can be saved as parquet files, maintaining the schema information.
schemaPeople.saveAsParquetFile("people.parquet");
// Read in the parquet file created above.
// Parquet files are self-describing so the schema is preserved.
- // The result of loading a parquet file is also a JavaSchemaRDD.
- SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet");
+ // The result of loading a parquet file is also a DataFrame.
+ DataFrame parquetFile = sqlCtx.parquetFile("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
- SchemaRDD teenagers2 =
+ DataFrame teenagers2 =
sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
teenagerNames = teenagers2.toJavaRDD().map(new Function<Row, String>() {
@Override
@@ -119,8 +119,8 @@ public String call(Row row) {
// A JSON dataset is pointed by path.
// The path can be either a single text file or a directory storing text files.
String path = "examples/src/main/resources/people.json";
- // Create a JavaSchemaRDD from the file(s) pointed by path
- SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path);
+ // Create a DataFrame from the file(s) pointed by path
+ DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path);
// Because the schema of a JSON dataset is automatically inferred, to write queries,
// it is better to take a look at what is the schema.
@@ -130,13 +130,13 @@ public String call(Row row) {
// |-- age: IntegerType
// |-- name: StringType
- // Register this JavaSchemaRDD as a table.
+ // Register this DataFrame as a table.
peopleFromJsonFile.registerTempTable("people");
// SQL statements can be run by using the sql methods provided by sqlCtx.
- SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
- // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations.
+ // The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
teenagerNames = teenagers3.toJavaRDD().map(new Function<Row, String>() {
@Override
@@ -146,14 +146,14 @@ public String call(Row row) {
System.out.println(name);
}
- // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by
+ // Alternatively, a DataFrame can be created for a JSON dataset represented by
// a RDD[String] storing one JSON object per string.
List<String> jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD<String> anotherPeopleRDD = ctx.parallelize(jsonData);
- SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd());
+ DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd());
- // Take a look at the schema of this new JavaSchemaRDD.
+ // Take a look at the schema of this new DataFrame.
peopleFromJsonRDD.printSchema();
// The schema of anotherPeople is ...
// root
@@ -164,7 +164,7 @@ public String call(Row row) {
peopleFromJsonRDD.registerTempTable("people2");
- SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2");
+ DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2");
List<String> nameAndCity = peopleWithCity.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) {
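
The changes in this file are mechanical: SchemaRDD becomes DataFrame and registerAsTable becomes registerTempTable; the query surface is otherwise unchanged. A minimal Scala sketch of the same flow with the new names (assuming an existing SparkContext `sc` and the bundled people.json):

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)

    // jsonFile infers the schema and returns a DataFrame (formerly a SchemaRDD).
    val people = sqlContext.jsonFile("examples/src/main/resources/people.json")
    people.printSchema()
    people.registerTempTable("people")

    // Query results are DataFrames as well.
    val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
    teenagers.collect().foreach(println)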
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
index 540dae785f6ea..b5a70db2b9a3c 100644
--- a/examples/src/main/python/mllib/dataset_example.py
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -16,7 +16,7 @@
#
"""
-An example of how to use SchemaRDD as a dataset for ML. Run with::
+An example of how to use DataFrame as a dataset for ML. Run with::
bin/spark-submit examples/src/main/python/mllib/dataset_example.py
"""
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index d2c5ca48c6cb8..7f5c68e3d0fe2 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -30,18 +30,18 @@
some_rdd = sc.parallelize([Row(name="John", age=19),
Row(name="Smith", age=23),
Row(name="Sarah", age=18)])
- # Infer schema from the first row, create a SchemaRDD and print the schema
- some_schemardd = sqlContext.inferSchema(some_rdd)
- some_schemardd.printSchema()
+ # Infer schema from the first row, create a DataFrame and print the schema
+ some_df = sqlContext.inferSchema(some_rdd)
+ some_df.printSchema()
# Another RDD is created from a list of tuples
another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)])
# Schema with two fields - person_name and person_age
schema = StructType([StructField("person_name", StringType(), False),
StructField("person_age", IntegerType(), False)])
- # Create a SchemaRDD by applying the schema to the RDD and print the schema
- another_schemardd = sqlContext.applySchema(another_rdd, schema)
- another_schemardd.printSchema()
+ # Create a DataFrame by applying the schema to the RDD and print the schema
+ another_df = sqlContext.applySchema(another_rdd, schema)
+ another_df.printSchema()
# root
# |-- age: integer (nullable = true)
# |-- name: string (nullable = true)
@@ -49,7 +49,7 @@
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json")
- # Create a SchemaRDD from the file(s) pointed to by path
+ # Create a DataFrame from the file(s) pointed to by path
people = sqlContext.jsonFile(path)
# root
# |-- person_name: string (nullable = false)
@@ -61,7 +61,7 @@
# |-- age: IntegerType
# |-- name: StringType
- # Register this SchemaRDD as a table.
+ # Register this DataFrame as a table.
people.registerAsTable("people")
# SQL statements can be run by using the sql methods provided by sqlContext
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index d8c7ef38ee46d..283bb80f1c788 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -18,7 +18,6 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
@@ -101,7 +100,7 @@ object CrossValidatorExample {
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test)
- .select('id, 'text, 'score, 'prediction)
+ .select("id", "text", "score", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
new file mode 100644
index 0000000000000..b7885829459a3
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.recommendation.ALS
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
+ * Run with
+ * {{{
+ * bin/run-example ml.MovieLensALS
+ * }}}
+ */
+object MovieLensALS {
+
+ case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
+
+ object Rating {
+ def parseRating(str: String): Rating = {
+ val fields = str.split("::")
+ assert(fields.size == 4)
+ Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
+ }
+ }
+
+ case class Movie(movieId: Int, title: String, genres: Seq[String])
+
+ object Movie {
+ def parseMovie(str: String): Movie = {
+ val fields = str.split("::")
+ assert(fields.size == 3)
+ Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
+ }
+ }
+
+ case class Params(
+ ratings: String = null,
+ movies: String = null,
+ maxIter: Int = 10,
+ regParam: Double = 0.1,
+ rank: Int = 10,
+ numBlocks: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("MovieLensALS") {
+ head("MovieLensALS: an example app for ALS on MovieLens data.")
+ opt[String]("ratings")
+ .required()
+ .text("path to a MovieLens dataset of ratings")
+ .action((x, c) => c.copy(ratings = x))
+ opt[String]("movies")
+ .required()
+ .text("path to a MovieLens dataset of movies")
+ .action((x, c) => c.copy(movies = x))
+ opt[Int]("rank")
+ .text(s"rank, default: ${defaultParams.rank}}")
+ .action((x, c) => c.copy(rank = x))
+ opt[Int]("maxIter")
+ .text(s"max number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Int]("numBlocks")
+ .text(s"number of blocks, default: ${defaultParams.numBlocks}")
+ .action((x, c) => c.copy(numBlocks = x))
+ note(
+ """
+ |Example command line to run this app:
+ |
+ | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
+ | examples/target/scala-*/spark-examples-*.jar \
+ | --rank 10 --maxIter 15 --regParam 0.1 \
+ | --movies path/to/movielens/movies.dat \
+ | --ratings path/to/movielens/ratings.dat
+ """.stripMargin)
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ } getOrElse {
+ System.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()
+
+ val numRatings = ratings.count()
+ val numUsers = ratings.map(_.userId).distinct().count()
+ val numMovies = ratings.map(_.movieId).distinct().count()
+
+ println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
+
+ val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
+ val training = splits(0).cache()
+ val test = splits(1).cache()
+
+ val numTraining = training.count()
+ val numTest = test.count()
+ println(s"Training: $numTraining, test: $numTest.")
+
+ ratings.unpersist(blocking = false)
+
+ val als = new ALS()
+ .setUserCol("userId")
+ .setItemCol("movieId")
+ .setRank(params.rank)
+ .setMaxIter(params.maxIter)
+ .setRegParam(params.regParam)
+ .setNumBlocks(params.numBlocks)
+
+ val model = als.fit(training)
+
+ val predictions = model.transform(test).cache()
+
+ // Evaluate the model.
+ // TODO: Create an evaluator to compute RMSE.
+ val mse = predictions.select("rating", "prediction").rdd
+ .flatMap { case Row(rating: Float, prediction: Float) =>
+ val err = rating.toDouble - prediction
+ val err2 = err * err
+ if (err2.isNaN) {
+ None
+ } else {
+ Some(err2)
+ }
+ }.mean()
+ val rmse = math.sqrt(mse)
+ println(s"Test RMSE = $rmse.")
+
+ // Inspect false positives.
+ predictions.registerTempTable("prediction")
+ sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie")
+ sqlContext.sql(
+ """
+ |SELECT userId, prediction.movieId, title, rating, prediction
+ | FROM prediction JOIN movie ON prediction.movieId = movie.movieId
+ | WHERE rating <= 1 AND prediction >= 4
+ | LIMIT 100
+ """.stripMargin)
+ .collect()
+ .foreach(println)
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index e8a2adff929cb..95cc9801eaeb9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -18,7 +18,6 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -42,7 +41,7 @@ object SimpleParamsExample {
// Prepare training data.
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
- // into SchemaRDDs, where it uses the bean metadata to infer the schema.
+ // into DataFrames, where it uses the bean metadata to infer the schema.
val training = sparkContext.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -92,7 +91,7 @@ object SimpleParamsExample {
// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
// column since we renamed the lr.scoreCol parameter previously.
model2.transform(test)
- .select('features, 'label, 'probability, 'prediction)
+ .select("features", "label", "probability", "prediction")
.collect()
.foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index b9a6ef0229def..065db62b0f5ed 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -20,7 +20,6 @@ package org.apache.spark.examples.ml
import scala.beans.BeanInfo
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
@@ -80,7 +79,7 @@ object SimpleTextClassificationPipeline {
// Make predictions on test documents.
model.transform(test)
- .select('id, 'text, 'score, 'prediction)
+ .select("id", "text", "score", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
index f8d83f4ec7327..f229a58985a3e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+import org.apache.spark.sql.{Row, SQLContext, DataFrame}
/**
- * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
* {{{
* ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
* }}}
@@ -47,7 +47,7 @@ object DatasetExample {
val defaultParams = Params()
val parser = new OptionParser[Params]("DatasetExample") {
- head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ head("Dataset: an example app using DataFrame as a Dataset for ML.")
opt[String]("input")
.text(s"input path to dataset")
.action((x, c) => c.copy(input = x))
@@ -80,20 +80,20 @@ object DatasetExample {
}
println(s"Loaded ${origData.count()} instances from file: ${params.input}")
- // Convert input data to SchemaRDD explicitly.
- val schemaRDD: SchemaRDD = origData
- println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
- println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+ // Convert input data to DataFrame explicitly.
+ val df: DataFrame = origData.toDF
+ println(s"Inferred schema:\n${df.schema.prettyJson}")
+ println(s"Converted to DataFrame with ${df.count()} records")
- // Select columns, using implicit conversion to SchemaRDD.
- val labelsSchemaRDD: SchemaRDD = origData.select('label)
- val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ // Select columns, using implicit conversion to DataFrames.
+ val labelsDf: DataFrame = origData.select("label")
+ val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
val numLabels = labels.count()
val meanLabel = labels.fold(0.0)(_ + _) / numLabels
println(s"Selected label column with average value $meanLabel")
- val featuresSchemaRDD: SchemaRDD = origData.select('features)
- val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featuresDf: DataFrame = origData.select("features")
+ val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
@@ -103,13 +103,13 @@ object DatasetExample {
tmpDir.deleteOnExit()
val outputDir = new File(tmpDir, "dataset").toString
println(s"Saving to $outputDir as Parquet file.")
- schemaRDD.saveAsParquetFile(outputDir)
+ df.saveAsParquetFile(outputDir)
println(s"Loading Parquet file with UDT from $outputDir.")
val newDataset = sqlContext.parquetFile(outputDir)
println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
- val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
index 2e98b2dc30b80..a5d7f262581f5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
@@ -19,6 +19,8 @@ package org.apache.spark.examples.sql
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.dsl.literals._
// One method for defining the schema of an RDD is to make a case class with the desired column
// names and types.
@@ -54,7 +56,7 @@ object RDDRelation {
rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println)
// Queries can also be written using a LINQ-like Scala DSL.
- rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println)
+ rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println)
// Write out an RDD as a parquet file.
rdd.saveAsParquetFile("pair.parquet")
@@ -63,7 +65,7 @@ object RDDRelation {
val parquetFile = sqlContext.parquetFile("pair.parquet")
// Queries can be run using the DSL on parquet files just like the original RDD.
- parquetFile.where('key === 1).select('value as 'a).collect().foreach(println)
+ parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println)
// These files can also be registered as tables.
parquetFile.registerTempTable("parquetFile")
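
The replacement of the symbol-based DSL ('key === 1) with $"key" === 1 reflects the move to Column expressions; the $ interpolator comes from the org.apache.spark.sql.dsl._ import added above (the package name was still settling at this point in development). An illustrative sketch, assuming a DataFrame `df` with key and value columns:

    // Old symbol DSL:  df.where('key === 1).orderBy('value.asc).select('key)
    // New Column DSL:  expressions built with the $"..." interpolator.
    val result = df
      .where($"key" === 1)
      .orderBy($"value".asc)
      .select($"key")
    result.collect().foreach(println)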
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
index 897c7ee12a436..f1550ac2e18ad 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
@@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl
import scala.reflect.{classTag, ClassTag}
-import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
* partitioner that allows co-partitioning with `partitionsRDD`.
*/
override val partitioner =
- partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+ partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size)))
override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 9da0064104fb6..ed9876b8dc21c 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
}
+ test("non-default number of edge partitions") {
+ val n = 10
+ val defaultParallelism = 3
+ val numEdgePartitions = 4
+ assert(defaultParallelism != numEdgePartitions)
+ val conf = new org.apache.spark.SparkConf()
+ .set("spark.default.parallelism", defaultParallelism.toString)
+ val sc = new SparkContext("local", "test", conf)
+ try {
+ val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)),
+ numEdgePartitions)
+ val graph = Graph.fromEdgeTuples(edges, 1)
+ val neighborAttrSums = graph.mapReduceTriplets[Int](
+ et => Iterator((et.dstId, et.srcAttr)), _ + _)
+ assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n)))
+ } finally {
+ sc.stop()
+ }
+ }
+
}
diff --git a/make-distribution.sh b/make-distribution.sh
index 4e2f400be3053..0adca7851819b 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -115,7 +115,7 @@ if which git &>/dev/null; then
unset GITREV
fi
-if ! which $MVN &>/dev/null; then
+if ! which "$MVN" &>/dev/null; then
echo -e "Could not locate Maven command: '$MVN'."
echo -e "Specify the Maven command with the --mvn flag"
exit -1;
@@ -171,13 +171,16 @@ cd "$SPARK_HOME"
export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
-BUILD_COMMAND="$MVN clean package -DskipTests $@"
+# Store the command as an array because $MVN variable might have spaces in it.
+# Normal quoting tricks don't work.
+# See: http://mywiki.wooledge.org/BashFAQ/050
+BUILD_COMMAND=("$MVN" clean package -DskipTests $@)
# Actually build the jar
echo -e "\nBuilding with..."
-echo -e "\$ $BUILD_COMMAND\n"
+echo -e "\$ ${BUILD_COMMAND[@]}\n"
-${BUILD_COMMAND}
+"${BUILD_COMMAND[@]}"
# Make directories
rm -rf "$DISTDIR"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 77d230eb4a122..bc3defe968afd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -21,7 +21,7 @@ import scala.annotation.varargs
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
/**
* :: AlphaComponent ::
@@ -38,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @return fitted model
*/
@varargs
- def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
+ def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
val map = new ParamMap().put(paramPairs: _*)
fit(dataset, map)
}
@@ -50,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @param paramMap parameter map
* @return fitted model
*/
- def fit(dataset: SchemaRDD, paramMap: ParamMap): M
+ def fit(dataset: DataFrame, paramMap: ParamMap): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
@@ -61,7 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @param paramMaps an array of parameter maps
* @return fitted models, matching the input parameter maps
*/
- def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
+ def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
}
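
Estimator.fit now accepts a DataFrame plus either individual ParamPairs (the varargs overload simply folds them into a ParamMap and delegates) or an explicit ParamMap. A hedged usage sketch, mirroring the style of SimpleParamsExample elsewhere in this patch; `training` (a DataFrame of labeled data) and `lr` (a LogisticRegression) are assumed to exist:

    import org.apache.spark.ml.param.ParamMap

    // Varargs overload: one-off parameter overrides.
    val model1 = lr.fit(training, lr.maxIter -> 10, lr.regParam -> 0.01)

    // ParamMap overload: collect overrides once and reuse them.
    val paramMap = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.1)
    val model2 = lr.fit(training, paramMap)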
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
index db563dd550e56..d2ca2e6871e6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
/**
* :: AlphaComponent ::
@@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
- def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
+ def evaluate(dataset: DataFrame, paramMap: ParamMap): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index ad6fed178fae9..fe39cd1bc0bd2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
@@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] {
* @param paramMap parameter map
* @return fitted pipeline
*/
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val theStages = map(stages)
@@ -162,7 +162,7 @@ class PipelineModel private[ml] (
}
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
transformSchema(dataset.schema, map, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index af56f9c435351..b233bff08305c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -22,9 +22,9 @@ import scala.annotation.varargs
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.expressions.ScalaUdf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types._
/**
@@ -41,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params {
* @return transformed dataset
*/
@varargs
- def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
+ def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = {
val map = new ParamMap()
paramPairs.foreach(map.put(_))
transform(dataset, map)
@@ -53,7 +53,7 @@ abstract class Transformer extends PipelineStage with Params {
* @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
+ def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame
}
/**
@@ -95,11 +95,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
StructType(outputFields)
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
- dataset.select(Star(None), udf as map(outputCol))
+ dataset.select($"*", callUDF(
+ this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8c570812f8316..eeb6301c3f64a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
def setScoreCol(value: String): this.type = set(scoreCol, value)
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+ val instances = dataset.select(map(labelCol), map(featuresCol))
.map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}.persist(StorageLevel.MEMORY_AND_DISK)
@@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] (
validateAndTransformSchema(schema, paramMap, fitting = false)
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val score: Vector => Double = (v) => {
val margin = BLAS.dot(v, weights)
@@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] (
val predict: Double => Double = (score) => {
if (score > t) 1.0 else 0.0
}
- dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol))
- .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol))
+ dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
+ .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 12473cb2b5719..1979ab9eb6516 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.sql.{Row, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
def setScoreCol(value: String): this.type = set(scoreCol, value)
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
+ override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
@@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
require(labelType == DoubleType,
s"Label column ${map(labelCol)} must be double type but found $labelType")
- import dataset.sqlContext._
- val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr)
+ val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol))
.map { case Row(score: Double, label: Double) =>
(score, label)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 72825f6e02182..e7bdb070c8193 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -43,14 +43,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val input = dataset.select(map(inputCol).attr)
- .map { case Row(v: Vector) =>
- v
- }
+ val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(this, map, scaler)
Params.inheritValues(map, this, model)
@@ -83,14 +79,13 @@ class StandardScalerModel private[ml] (
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val scale: (Vector) => Vector = (v) => {
scaler.transform(v)
}
- dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
+ dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
new file mode 100644
index 0000000000000..f6437c7fbc8ed
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -0,0 +1,970 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import java.{util => ju}
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.netlib.util.intW
+
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Common params for ALS.
+ */
+private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
+ with HasPredictionCol {
+
+ /** Param for rank of the matrix factorization. */
+ val rank = new IntParam(this, "rank", "rank of the factorization", Some(10))
+ def getRank: Int = get(rank)
+
+ /** Param for number of user blocks. */
+ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10))
+ def getNumUserBlocks: Int = get(numUserBlocks)
+
+ /** Param for number of item blocks. */
+ val numItemBlocks =
+ new IntParam(this, "numItemBlocks", "number of item blocks", Some(10))
+ def getNumItemBlocks: Int = get(numItemBlocks)
+
+ /** Param to decide whether to use implicit preference. */
+ val implicitPrefs =
+ new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false))
+ def getImplicitPrefs: Boolean = get(implicitPrefs)
+
+ /** Param for the alpha parameter in the implicit preference formulation. */
+ val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0))
+ def getAlpha: Double = get(alpha)
+
+ /** Param for the column name for user ids. */
+ val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user"))
+ def getUserCol: String = get(userCol)
+
+ /** Param for the column name for item ids. */
+ val itemCol =
+ new Param[String](this, "itemCol", "column name for item ids", Some("item"))
+ def getItemCol: String = get(itemCol)
+
+ /** Param for the column name for ratings. */
+ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
+ def getRatingCol: String = get(ratingCol)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @param paramMap extra params
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ assert(schema(map(userCol)).dataType == IntegerType)
+ assert(schema(map(itemCol)).dataType == IntegerType)
+ val ratingType = schema(map(ratingCol)).dataType
+ assert(ratingType == FloatType || ratingType == DoubleType)
+ val predictionColName = map(predictionCol)
+ assert(!schema.fieldNames.contains(predictionColName),
+ s"Prediction column $predictionColName already exists.")
+ val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false)
+ StructType(newFields)
+ }
+}
+
+/**
+ * Model fitted by ALS.
+ */
+class ALSModel private[ml] (
+ override val parent: ALS,
+ override val fittingParamMap: ParamMap,
+ k: Int,
+ userFactors: RDD[(Int, Array[Float])],
+ itemFactors: RDD[(Int, Array[Float])])
+ extends Model[ALSModel] with ALSParams {
+
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ import dataset.sqlContext._
+ import org.apache.spark.ml.recommendation.ALSModel.Factor
+ val map = this.paramMap ++ paramMap
+ // TODO: Add DSL to simplify the code here.
+ val instanceTable = s"instance_$uid"
+ val userTable = s"user_$uid"
+ val itemTable = s"item_$uid"
+ val instances = dataset.as(instanceTable)
+ val users = userFactors.map { case (id, features) =>
+ Factor(id, features)
+ }.as(userTable)
+ val items = itemFactors.map { case (id, features) =>
+ Factor(id, features)
+ }.as(itemTable)
+ val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
+ if (userFeatures != null && itemFeatures != null) {
+ blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+ } else {
+ Float.NaN
+ }
+ }
+ val inputColumns = dataset.schema.fieldNames
+ val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
+ .as(map(predictionCol))
+ val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
+ instances
+ .join(users, Column(map(userCol)) === $"$userTable.id", "left")
+ .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+ .select(outputColumns: _*)
+ }
+
+ override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
+
+private object ALSModel {
+ /** Case class to convert factors to DataFrames. */
+ private case class Factor(id: Int, features: Seq[Float])
+}
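The prediction emitted by transform above is just the dot product of the matching user and item factor vectors (NaN when either side is missing). A tiny self-contained illustration with made-up rank-3 factors:

{{{
val userFactor = Array(0.1f, 0.4f, 0.2f)
val itemFactor = Array(1.0f, 0.5f, 2.0f)
// The same quantity blas.sdot computes inside `predict` above.
val prediction = userFactor.zip(itemFactor).map { case (u, v) => u * v }.sum  // ~= 0.7f
}}}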
+
+/**
+ * Alternating Least Squares (ALS) matrix factorization.
+ *
+ * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
+ * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
+ * The general approach is iterative. During each iteration, one of the factor matrices is held
+ * constant, while the other is solved for using least squares. The newly-solved factor matrix is
+ * then held constant while solving for the other factor matrix.
+ *
+ * This is a blocked implementation of the ALS factorization algorithm that groups the two sets
+ * of factors (referred to as "users" and "products") into blocks and reduces communication by only
+ * sending one copy of each user vector to each product block on each iteration, and only for the
+ * product blocks that need that user's feature vector. This is achieved by pre-computing some
+ * information about the ratings matrix to determine the "out-links" of each user (which blocks of
+ * products it will contribute to) and "in-link" information for each product (which of the feature
+ * vectors it receives from each user block it will depend on). This allows us to send only an
+ * array of feature vectors between each user block and product block, and have the product block
+ * find the users' ratings and update the products based on these messages.
+ *
+ * For implicit preference data, the algorithm used is based on
+ * "Collaborative Filtering for Implicit Feedback Datasets", available at
+ * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here.
+ *
+ * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
+ * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
+ * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to the strength of
+ * indicated user preferences rather than explicit ratings given to items.
+ */
+class ALS extends Estimator[ALSModel] with ALSParams {
+
+ import org.apache.spark.ml.recommendation.ALS.Rating
+
+ def setRank(value: Int): this.type = set(rank, value)
+ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)
+ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)
+ def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)
+ def setAlpha(value: Double): this.type = set(alpha, value)
+ def setUserCol(value: String): this.type = set(userCol, value)
+ def setItemCol(value: String): this.type = set(itemCol, value)
+ def setRatingCol(value: String): this.type = set(ratingCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** Sets both numUserBlocks and numItemBlocks to the specified value. */
+ def setNumBlocks(value: Int): this.type = {
+ setNumUserBlocks(value)
+ setNumItemBlocks(value)
+ this
+ }
+
+ setMaxIter(20)
+ setRegParam(1.0)
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
+ val map = this.paramMap ++ paramMap
+ val ratings = dataset
+ .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+ .map { row =>
+ new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+ }
+ val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
+ numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
+ maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
+ alpha = map(alpha))
+ val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
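A minimal usage sketch for the estimator, assuming a DataFrame `ratingsDF` with integer "user"/"item" columns and a float "rating" column already exists, and assuming the no-extra-arguments fit/transform overloads inherited from Estimator/Transformer; the parameter values are illustrative.

{{{
val als = new ALS()
  .setRank(10)
  .setMaxIter(15)
  .setRegParam(0.1)
  .setImplicitPrefs(false)
  .setUserCol("user")
  .setItemCol("item")
  .setRatingCol("rating")
  .setPredictionCol("prediction")

val model = als.fit(ratingsDF)           // learns user and item factors
val scored = model.transform(ratingsDF)  // appends the "prediction" column
}}}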
+
+private[recommendation] object ALS extends Logging {
+
+ /** Rating class for better code readability. */
+ private[recommendation] case class Rating(user: Int, item: Int, rating: Float)
+
+ /** Cholesky solver for least squares problems. */
+ private[recommendation] class CholeskySolver {
+
+ private val upper = "U"
+ private val info = new intW(0)
+
+ /**
+ * Solves a least squares problem with L2 regularization:
+ *
+ * min norm(A x - b)^2^ + lambda * n * norm(x)^2^
+ *
+ * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
+ * @param lambda regularization constant, which will be scaled by n
+ * @return the solution x
+ */
+ def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
+ val k = ne.k
+ // Add scaled lambda to the diagonals of AtA.
+ val scaledlambda = lambda * ne.n
+ var i = 0
+ var j = 2
+ while (i < ne.triK) {
+ ne.ata(i) += scaledlambda
+ i += j
+ j += 1
+ }
+ lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info)
+ val code = info.`val`
+ assert(code == 0, s"lapack.dppsv returned $code.")
+ val x = new Array[Float](k)
+ i = 0
+ while (i < k) {
+ x(i) = ne.atb(i).toFloat
+ i += 1
+ }
+ ne.reset()
+ x
+ }
+ }
+
+ /** Represents a normal equation (an ALS subproblem). */
+ private[recommendation] class NormalEquation(val k: Int) extends Serializable {
+
+ /** Number of entries in the upper triangular part of a k-by-k matrix. */
+ val triK = k * (k + 1) / 2
+ /** A^T^ * A */
+ val ata = new Array[Double](triK)
+ /** A^T^ * b */
+ val atb = new Array[Double](k)
+ /** Number of observations. */
+ var n = 0
+
+ private val da = new Array[Double](k)
+ private val upper = "U"
+
+ private def copyToDouble(a: Array[Float]): Unit = {
+ var i = 0
+ while (i < k) {
+ da(i) = a(i)
+ i += 1
+ }
+ }
+
+ /** Adds an observation. */
+ def add(a: Array[Float], b: Float): this.type = {
+ require(a.size == k)
+ copyToDouble(a)
+ blas.dspr(upper, k, 1.0, da, 1, ata)
+ blas.daxpy(k, b.toDouble, da, 1, atb, 1)
+ n += 1
+ this
+ }
+
+ /**
+ * Adds an observation with implicit feedback. Note that this does not increment the counter.
+ */
+ def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
+ require(a.size == k)
+ // Extension to the original paper to handle b < 0: confidence is a function of |b|
+ // instead, so that it is never negative.
+ val confidence = 1.0 + alpha * math.abs(b)
+ copyToDouble(a)
+ blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
+ // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
+ if (b > 0) {
+ blas.daxpy(k, confidence, da, 1, atb, 1)
+ }
+ this
+ }
+
+ /** Merges another normal equation object. */
+ def merge(other: NormalEquation): this.type = {
+ require(other.k == k)
+ blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1)
+ blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1)
+ n += other.n
+ this
+ }
+
+ /** Resets everything to zero, which should be called after each solve. */
+ def reset(): Unit = {
+ ju.Arrays.fill(ata, 0.0)
+ ju.Arrays.fill(atb, 0.0)
+ n = 0
+ }
+ }
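A small sketch of how CholeskySolver and NormalEquation fit together for a single k = 2 subproblem; both classes are package-private, so the snippet assumes it lives in this package (e.g. a test), and the observations are made up.

{{{
val ne = new NormalEquation(2)
ne.add(Array(1.0f, 0.0f), 1.0f)   // explicit observation: row (1, 0), target 1.0
ne.add(Array(0.0f, 1.0f), 2.0f)   // explicit observation: row (0, 1), target 2.0
// For implicit feedback one would call ne.addImplicit(a, r, alpha) instead,
// which weights the row by confidence 1 + alpha * |r|.
val x = new CholeskySolver().solve(ne, lambda = 0.1)  // solves the regularized system and resets ne
}}}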
+
+ /**
+ * Implementation of the ALS algorithm.
+ */
+ private def train(
+ ratings: RDD[Rating],
+ rank: Int = 10,
+ numUserBlocks: Int = 10,
+ numItemBlocks: Int = 10,
+ maxIter: Int = 10,
+ regParam: Double = 1.0,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = {
+ val userPart = new HashPartitioner(numUserBlocks)
+ val itemPart = new HashPartitioner(numItemBlocks)
+ val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
+ val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
+ val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
+ val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
+ // materialize blockRatings and user blocks
+ userOutBlocks.count()
+ val swappedBlockRatings = blockRatings.map {
+ case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
+ ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
+ }
+ val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart)
+ // materialize item blocks
+ itemOutBlocks.count()
+ var userFactors = initialize(userInBlocks, rank)
+ var itemFactors = initialize(itemInBlocks, rank)
+ if (implicitPrefs) {
+ for (iter <- 1 to maxIter) {
+ userFactors.setName(s"userFactors-$iter").persist()
+ val previousItemFactors = itemFactors
+ itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
+ userLocalIndexEncoder, implicitPrefs, alpha)
+ previousItemFactors.unpersist()
+ itemFactors.setName(s"itemFactors-$iter").persist()
+ val previousUserFactors = userFactors
+ userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
+ itemLocalIndexEncoder, implicitPrefs, alpha)
+ previousUserFactors.unpersist()
+ }
+ } else {
+ for (iter <- 0 until maxIter) {
+ itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
+ userLocalIndexEncoder)
+ userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
+ itemLocalIndexEncoder)
+ }
+ }
+ val userIdAndFactors = userInBlocks
+ .mapValues(_.srcIds)
+ .join(userFactors)
+ .values
+ .setName("userFactors")
+ .cache()
+ userIdAndFactors.count()
+ itemFactors.unpersist()
+ val itemIdAndFactors = itemInBlocks
+ .mapValues(_.srcIds)
+ .join(itemFactors)
+ .values
+ .setName("itemFactors")
+ .cache()
+ itemIdAndFactors.count()
+ userInBlocks.unpersist()
+ userOutBlocks.unpersist()
+ itemInBlocks.unpersist()
+ itemOutBlocks.unpersist()
+ blockRatings.unpersist()
+ val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
+ ids.view.zip(factors)
+ }
+ val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) =>
+ ids.view.zip(factors)
+ }
+ (userOutput, itemOutput)
+ }
+
+ /**
+ * Factor block that stores factors (Array[Float]) in an Array.
+ */
+ private type FactorBlock = Array[Array[Float]]
+
+ /**
+ * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to
+ * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the
+ * src factors in this block to send to dst block 0.
+ */
+ private type OutBlock = Array[Array[Int]]
+
+ /**
+ * In-link block for computing src (user/item) factors. This includes the original src IDs
+ * of the elements within this block as well as encoded dst (item/user) indices and corresponding
+ * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original
+ * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices.
+ * For example, if we have an in-link record
+ *
+ * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0},
+ *
+ * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which
+ * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3).
+ *
+ * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can
+ * compute src factors one after another using only one normal equation instance.
+ *
+ * @param srcIds src ids (ordered)
+ * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and
+ * ratings are associated with srcIds(i).
+ * @param dstEncodedIndices encoded dst indices
+ * @param ratings ratings
+ *
+ * @see [[LocalIndexEncoder]]
+ */
+ private[recommendation] case class InBlock(
+ srcIds: Array[Int],
+ dstPtrs: Array[Int],
+ dstEncodedIndices: Array[Int],
+ ratings: Array[Float]) {
+ /** Size of the block. */
+ val size: Int = ratings.size
+
+ require(dstEncodedIndices.size == size)
+ require(dstPtrs.size == srcIds.size + 1)
+ }
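To make the CSC-like layout concrete, here is an illustrative in-block with made-up ids and ratings: src id 5 owns the first two entries and src id 9 the last one, as recorded by dstPtrs.

{{{
val encoder = new LocalIndexEncoder(2)
val block = InBlock(
  srcIds = Array(5, 9),
  dstPtrs = Array(0, 2, 3),
  dstEncodedIndices = Array(encoder.encode(0, 3), encoder.encode(1, 0), encoder.encode(0, 1)),
  ratings = Array(4.0f, 1.0f, 5.0f))
}}}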
+
+ /**
+ * Initializes factors randomly given the in-link blocks.
+ *
+ * @param inBlocks in-link blocks
+ * @param rank rank
+ * @return initialized factor blocks
+ */
+ private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = {
+ // Choose a unit vector uniformly at random from the unit sphere, but from the
+ // "first quadrant" where all elements are nonnegative. This can be done by choosing
+ // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
+ // This appears to create factorizations that have a slightly better reconstruction
+ // (<1%) compared to picking elements uniformly at random in [0,1].
+ inBlocks.map { case (srcBlockId, inBlock) =>
+ val random = new XORShiftRandom(srcBlockId)
+ val factors = Array.fill(inBlock.srcIds.size) {
+ val factor = Array.fill(rank)(random.nextGaussian().toFloat)
+ val nrm = blas.snrm2(rank, factor, 1)
+ blas.sscal(rank, 1.0f / nrm, factor, 1)
+ factor
+ }
+ (srcBlockId, factors)
+ }
+ }
+
+ /**
+ * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
+ */
+ private[recommendation]
+ case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) {
+ /** Size of the block. */
+ val size: Int = srcIds.size
+ require(dstIds.size == size)
+ require(ratings.size == size)
+ }
+
+ /**
+ * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing.
+ */
+ private[recommendation] class RatingBlockBuilder extends Serializable {
+
+ private val srcIds = mutable.ArrayBuilder.make[Int]
+ private val dstIds = mutable.ArrayBuilder.make[Int]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+ var size = 0
+
+ /** Adds a rating. */
+ def add(r: Rating): this.type = {
+ size += 1
+ srcIds += r.user
+ dstIds += r.item
+ ratings += r.rating
+ this
+ }
+
+ /** Merges another [[RatingBlockBuilder]]. */
+ def merge(other: RatingBlock): this.type = {
+ size += other.srcIds.size
+ srcIds ++= other.srcIds
+ dstIds ++= other.dstIds
+ ratings ++= other.ratings
+ this
+ }
+
+ /** Builds a [[RatingBlock]]. */
+ def build(): RatingBlock = {
+ RatingBlock(srcIds.result(), dstIds.result(), ratings.result())
+ }
+ }
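For illustration, building a tiny rating block from two made-up ratings:

{{{
val builder = new RatingBlockBuilder
builder.add(Rating(user = 0, item = 7, rating = 4.0f))
builder.add(Rating(user = 3, item = 7, rating = 1.0f))
val block = builder.build()
// block.srcIds is [0, 3], block.dstIds is [7, 7], block.ratings is [4.0, 1.0]
}}}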
+
+ /**
+ * Partitions raw ratings into blocks.
+ *
+ * @param ratings raw ratings
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ *
+ * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
+ */
+ private def partitionRatings(
+ ratings: RDD[Rating],
+ srcPart: Partitioner,
+ dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = {
+
+ /* The implementation produces the same result as the following but generates fewer objects.
+
+ ratings.map { r =>
+ ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r)
+ }.aggregateByKey(new RatingBlockBuilder)(
+ seqOp = (b, r) => b.add(r),
+ combOp = (b0, b1) => b0.merge(b1.build()))
+ .mapValues(_.build())
+ */
+
+ val numPartitions = srcPart.numPartitions * dstPart.numPartitions
+ ratings.mapPartitions { iter =>
+ val builders = Array.fill(numPartitions)(new RatingBlockBuilder)
+ iter.flatMap { r =>
+ val srcBlockId = srcPart.getPartition(r.user)
+ val dstBlockId = dstPart.getPartition(r.item)
+ val idx = srcBlockId + srcPart.numPartitions * dstBlockId
+ val builder = builders(idx)
+ builder.add(r)
+ if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k
+ builders(idx) = new RatingBlockBuilder
+ Iterator.single(((srcBlockId, dstBlockId), builder.build()))
+ } else {
+ Iterator.empty
+ }
+ } ++ {
+ builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
+ val srcBlockId = idx % srcPart.numPartitions
+ val dstBlockId = idx / srcPart.numPartitions
+ ((srcBlockId, dstBlockId), block.build())
+ }
+ }
+ }.groupByKey().mapValues { blocks =>
+ val builder = new RatingBlockBuilder
+ blocks.foreach(builder.merge)
+ builder.build()
+ }.setName("ratingBlocks")
+ }
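The flat builder index used above is plain arithmetic over the (srcBlockId, dstBlockId) grid; a sketch of the round trip with illustrative sizes:

{{{
val numSrcParts = 3                              // srcPart.numPartitions
val (srcBlockId, dstBlockId) = (2, 1)
val idx = srcBlockId + numSrcParts * dstBlockId  // 5, as computed inside mapPartitions
assert(idx % numSrcParts == srcBlockId)          // 2, recovered when flushing the builders
assert(idx / numSrcParts == dstBlockId)          // 1
}}}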
+
+ /**
+ * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
+ * @param encoder encoder for dst indices
+ */
+ private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) {
+
+ private val srcIds = mutable.ArrayBuilder.make[Int]
+ private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+
+ /**
+ * Adds a dst block of (srcId, dstLocalIndex, rating) tuples.
+ *
+ * @param dstBlockId dst block ID
+ * @param srcIds original src IDs
+ * @param dstLocalIndices dst local indices
+ * @param ratings ratings
+ */
+ def add(
+ dstBlockId: Int,
+ srcIds: Array[Int],
+ dstLocalIndices: Array[Int],
+ ratings: Array[Float]): this.type = {
+ val sz = srcIds.size
+ require(dstLocalIndices.size == sz)
+ require(ratings.size == sz)
+ this.srcIds ++= srcIds
+ this.ratings ++= ratings
+ var j = 0
+ while (j < sz) {
+ this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j))
+ j += 1
+ }
+ this
+ }
+
+ /** Builds an [[UncompressedInBlock]]. */
+ def build(): UncompressedInBlock = {
+ new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
+ }
+ }
+
+ /**
+ * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
+ */
+ private[recommendation] class UncompressedInBlock(
+ val srcIds: Array[Int],
+ val dstEncodedIndices: Array[Int],
+ val ratings: Array[Float]) {
+
+ /** Size of the block. */
+ def size: Int = srcIds.size
+
+ /**
+ * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
+ * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
+ * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
+ */
+ def compress(): InBlock = {
+ val sz = size
+ assert(sz > 0, "Empty in-link block should not exist.")
+ sort()
+ val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
+ val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
+ var preSrcId = srcIds(0)
+ uniqueSrcIdsBuilder += preSrcId
+ var curCount = 1
+ var i = 1
+ while (i < sz) {
+ val srcId = srcIds(i)
+ if (srcId != preSrcId) {
+ uniqueSrcIdsBuilder += srcId
+ dstCountsBuilder += curCount
+ preSrcId = srcId
+ curCount = 0
+ }
+ curCount += 1
+ i += 1
+ }
+ dstCountsBuilder += curCount
+ val uniqueSrcIds = uniqueSrcIdsBuilder.result()
+ val numUniqueSrcIds = uniqueSrcIds.size
+ val dstCounts = dstCountsBuilder.result()
+ val dstPtrs = new Array[Int](numUniqueSrcIds + 1)
+ var sum = 0
+ i = 0
+ while (i < numUniqueSrcIds) {
+ sum += dstCounts(i)
+ i += 1
+ dstPtrs(i) = sum
+ }
+ InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
+ }
+
+ private def sort(): Unit = {
+ val sz = size
+ // Since there might be interleaved log messages, we insert a unique id for easy pairing.
+ val sortId = Utils.random.nextInt()
+ logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
+ val start = System.nanoTime()
+ val sorter = new Sorter(new UncompressedInBlockSort)
+ sorter.sort(this, 0, size, Ordering[IntWrapper])
+ val duration = (System.nanoTime() - start) / 1e9
+ logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
+ }
+ }
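An end-to-end sketch of the COO-to-CSC compression with made-up data (again package-private code, so imagine it inside a test in this package):

{{{
val encoder = new LocalIndexEncoder(2)
val uncompressed = new UncompressedInBlockBuilder(encoder)
  .add(dstBlockId = 0, srcIds = Array(1, 4),
    dstLocalIndices = Array(3, 1), ratings = Array(1.0f, 3.0f))
  .add(dstBlockId = 1, srcIds = Array(1, 4),
    dstLocalIndices = Array(0, 2), ratings = Array(2.0f, 4.0f))
  .build()
val inBlock = uncompressed.compress()
// inBlock.srcIds is [1, 4] and inBlock.dstPtrs is [0, 2, 4]: after sorting by src id,
// entries [0, 2) belong to src id 1 and entries [2, 4) to src id 4.
}}}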
+
+ /**
+ * A wrapper that holds a primitive integer key.
+ *
+ * @see [[UncompressedInBlockSort]]
+ */
+ private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
+ override def compare(that: IntWrapper): Int = {
+ key.compare(that.key)
+ }
+ }
+
+ /**
+ * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
+ */
+ private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] {
+
+ override def newKey(): IntWrapper = new IntWrapper()
+
+ override def getKey(
+ data: UncompressedInBlock,
+ pos: Int,
+ reuse: IntWrapper): IntWrapper = {
+ if (reuse == null) {
+ new IntWrapper(data.srcIds(pos))
+ } else {
+ reuse.key = data.srcIds(pos)
+ reuse
+ }
+ }
+
+ override def getKey(
+ data: UncompressedInBlock,
+ pos: Int): IntWrapper = {
+ getKey(data, pos, null)
+ }
+
+ private def swapElements[@specialized(Int, Float) T](
+ data: Array[T],
+ pos0: Int,
+ pos1: Int): Unit = {
+ val tmp = data(pos0)
+ data(pos0) = data(pos1)
+ data(pos1) = tmp
+ }
+
+ override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = {
+ swapElements(data.srcIds, pos0, pos1)
+ swapElements(data.dstEncodedIndices, pos0, pos1)
+ swapElements(data.ratings, pos0, pos1)
+ }
+
+ override def copyRange(
+ src: UncompressedInBlock,
+ srcPos: Int,
+ dst: UncompressedInBlock,
+ dstPos: Int,
+ length: Int): Unit = {
+ System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
+ System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
+ System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
+ }
+
+ override def allocate(length: Int): UncompressedInBlock = {
+ new UncompressedInBlock(
+ new Array[Int](length), new Array[Int](length), new Array[Float](length))
+ }
+
+ override def copyElement(
+ src: UncompressedInBlock,
+ srcPos: Int,
+ dst: UncompressedInBlock,
+ dstPos: Int): Unit = {
+ dst.srcIds(dstPos) = src.srcIds(srcPos)
+ dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
+ dst.ratings(dstPos) = src.ratings(srcPos)
+ }
+ }
+
+ /**
+ * Creates in-blocks and out-blocks from rating blocks.
+ * @param prefix prefix for in/out-block names
+ * @param ratingBlocks rating blocks
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ * @return (in-blocks, out-blocks)
+ */
+ private def makeBlocks(
+ prefix: String,
+ ratingBlocks: RDD[((Int, Int), RatingBlock)],
+ srcPart: Partitioner,
+ dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = {
+ val inBlocks = ratingBlocks.map {
+ case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
+ // The implementation is a faster version of
+ // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
+ val start = System.nanoTime()
+ val dstIdSet = new OpenHashSet[Int](1 << 20)
+ dstIds.foreach(dstIdSet.add)
+ val sortedDstIds = new Array[Int](dstIdSet.size)
+ var i = 0
+ var pos = dstIdSet.nextPos(0)
+ while (pos != -1) {
+ sortedDstIds(i) = dstIdSet.getValue(pos)
+ pos = dstIdSet.nextPos(pos + 1)
+ i += 1
+ }
+ assert(i == dstIdSet.size)
+ ju.Arrays.sort(sortedDstIds)
+ val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size)
+ i = 0
+ while (i < sortedDstIds.size) {
+ dstIdToLocalIndex.update(sortedDstIds(i), i)
+ i += 1
+ }
+ logDebug(
+ "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
+ val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
+ (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
+ }.groupByKey(new HashPartitioner(srcPart.numPartitions))
+ .mapValues { iter =>
+ val builder =
+ new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions))
+ iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
+ builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
+ }
+ builder.build().compress()
+ }.setName(prefix + "InBlocks").cache()
+ val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
+ val encoder = new LocalIndexEncoder(dstPart.numPartitions)
+ val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
+ var i = 0
+ val seen = new Array[Boolean](dstPart.numPartitions)
+ while (i < srcIds.size) {
+ var j = dstPtrs(i)
+ ju.Arrays.fill(seen, false)
+ while (j < dstPtrs(i + 1)) {
+ val dstBlockId = encoder.blockId(dstEncodedIndices(j))
+ if (!seen(dstBlockId)) {
+ activeIds(dstBlockId) += i // add the local index in this out-block
+ seen(dstBlockId) = true
+ }
+ j += 1
+ }
+ i += 1
+ }
+ activeIds.map { x =>
+ x.result()
+ }
+ }.setName(prefix + "OutBlocks").cache()
+ (inBlocks, outBlocks)
+ }
+
+ /**
+ * Computes dst factors by constructing and solving least squares problems.
+ *
+ * @param srcFactorBlocks src factors
+ * @param srcOutBlocks src out-blocks
+ * @param dstInBlocks dst in-blocks
+ * @param rank rank
+ * @param regParam regularization constant
+ * @param srcEncoder encoder for src local indices
+ * @param implicitPrefs whether to use implicit preference
+ * @param alpha the alpha constant in the implicit preference formulation
+ *
+ * @return dst factors
+ */
+ private def computeFactors(
+ srcFactorBlocks: RDD[(Int, FactorBlock)],
+ srcOutBlocks: RDD[(Int, OutBlock)],
+ dstInBlocks: RDD[(Int, InBlock)],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: LocalIndexEncoder,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
+ val numSrcBlocks = srcFactorBlocks.partitions.size
+ val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
+ val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
+ case (srcBlockId, (srcOutBlock, srcFactors)) =>
+ srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) =>
+ (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
+ }
+ }
+ val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
+ dstInBlocks.join(merged).mapValues {
+ case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
+ val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
+ srcFactors.foreach { case (srcBlockId, factors) =>
+ sortedSrcFactors(srcBlockId) = factors
+ }
+ val dstFactors = new Array[Array[Float]](dstIds.size)
+ var j = 0
+ val ls = new NormalEquation(rank)
+ val solver = new CholeskySolver // TODO: add NNLS solver
+ while (j < dstIds.size) {
+ ls.reset()
+ if (implicitPrefs) {
+ ls.merge(YtY.get)
+ }
+ var i = srcPtrs(j)
+ while (i < srcPtrs(j + 1)) {
+ val encoded = srcEncodedIndices(i)
+ val blockId = srcEncoder.blockId(encoded)
+ val localIndex = srcEncoder.localIndex(encoded)
+ val srcFactor = sortedSrcFactors(blockId)(localIndex)
+ val rating = ratings(i)
+ if (implicitPrefs) {
+ ls.addImplicit(srcFactor, rating, alpha)
+ } else {
+ ls.add(srcFactor, rating)
+ }
+ i += 1
+ }
+ dstFactors(j) = solver.solve(ls, regParam)
+ j += 1
+ }
+ dstFactors
+ }
+ }
+
+ /**
+ * Computes the Gramian matrix of user or item factors, which is only used in implicit preference.
+ * Caching of the input factors is handled in [[ALS#train]].
+ */
+ private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
+ factorBlocks.values.aggregate(new NormalEquation(rank))(
+ seqOp = (ne, factors) => {
+ factors.foreach(ne.add(_, 0.0f))
+ ne
+ },
+ combOp = (ne1, ne2) => ne1.merge(ne2))
+ }
+
+ /**
+ * Encoder for storing (blockId, localIndex) into a single integer.
+ *
+ * We use the leading bits (including the sign bit) to store the block id and the rest to store
+ * the local index. This is based on the assumption that users/items are approximately evenly
+ * partitioned. With this assumption, we should be able to encode two billion distinct values.
+ *
+ * @param numBlocks number of blocks
+ */
+ private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable {
+
+ require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.")
+
+ private[this] final val numLocalIndexBits =
+ math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31)
+ private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1
+
+ /** Encodes a (blockId, localIndex) into a single integer. */
+ def encode(blockId: Int, localIndex: Int): Int = {
+ require(blockId < numBlocks)
+ require((localIndex & ~localIndexMask) == 0)
+ (blockId << numLocalIndexBits) | localIndex
+ }
+
+ /** Gets the block id from an encoded index. */
+ @inline
+ def blockId(encoded: Int): Int = {
+ encoded >>> numLocalIndexBits
+ }
+
+ /** Gets the local index from an encoded index. */
+ @inline
+ def localIndex(encoded: Int): Int = {
+ encoded & localIndexMask
+ }
+ }
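A round-trip sketch with 8 blocks, so 3 leading bits hold the block id and the remaining 29 bits the local index; the numbers are arbitrary.

{{{
val encoder = new LocalIndexEncoder(8)
val encoded = encoder.encode(blockId = 5, localIndex = 1000)
assert(encoder.blockId(encoded) == 5)
assert(encoder.localIndex(encoded) == 1000)
}}}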
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 08fe99176424a..5d51c51346665 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
@@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
def setNumFolds(value: Int): this.type = set(numFolds, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
transformSchema(dataset.schema, paramMap, logging = true)
@@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val epm = map(estimatorParamMaps)
val numModels = epm.size
val metrics = new Array[Double](epm.size)
- val splits = MLUtils.kFold(dataset, map(numFolds), 0)
+ val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.applySchema(training, schema).cache()
val validationDataset = sqlCtx.applySchema(validation, schema).cache()
@@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] (
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
bestModel.transform(dataset, paramMap)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 6b5c934f015ba..11633e8242313 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -279,45 +279,81 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
- // Initialize each run's center to a random point
+ // Initialize empty centers and point costs.
+ val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+ var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+ // Initialize each run's first center to a random point.
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
- val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+ val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+ /** Merges new centers into `centers`. */
+ def mergeNewCenters(): Unit = {
+ var r = 0
+ while (r < runs) {
+ centers(r) ++= newCenters(r)
+ newCenters(r).clear()
+ r += 1
+ }
+ }
// On each step, sample 2 * k points on average for each run with probability proportional
- // to their squared distance from that run's current centers
+ // to their squared distance from that run's centers. Note that only distances between points
+ // and new centers are computed in each iteration.
var step = 0
while (step < initializationSteps) {
- val bcCenters = data.context.broadcast(centers)
- val sumCosts = data.flatMap { point =>
- (0 until runs).map { r =>
- (r, KMeans.pointCost(bcCenters.value(r), point))
- }
- }.reduceByKey(_ + _).collectAsMap()
- val chosen = data.mapPartitionsWithIndex { (index, points) =>
+ val bcNewCenters = data.context.broadcast(newCenters)
+ val preCosts = costs
+ costs = data.zip(preCosts).map { case (point, cost) =>
+ Vectors.dense(
+ Array.tabulate(runs) { r =>
+ math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+ })
+ }.cache()
+ val sumCosts = costs
+ .aggregate(Vectors.zeros(runs))(
+ seqOp = (s, v) => {
+ // s += v
+ axpy(1.0, v, s)
+ s
+ },
+ combOp = (s0, s1) => {
+ // s0 += s1
+ axpy(1.0, s1, s0)
+ s0
+ }
+ )
+ preCosts.unpersist(blocking = false)
+ val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
- points.flatMap { p =>
- (0 until runs).filter { r =>
- rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
- }.map((_, p))
+ pointsWithCosts.flatMap { case (p, c) =>
+ val rs = (0 until runs).filter { r =>
+ rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
+ }
+ if (rs.length > 0) Some((p, rs)) else None
}
}.collect()
- chosen.foreach { case (r, p) =>
- centers(r) += p.toDense
+ mergeNewCenters()
+ chosen.foreach { case (p, rs) =>
+ rs.foreach(newCenters(_) += p.toDense)
}
step += 1
}
+ mergeNewCenters()
+ costs.unpersist(blocking = false)
+
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
- (0 until runs).map { r =>
+ Iterator.tabulate(runs) { r =>
((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
- val finalCenters = (0 until runs).map { r =>
+ val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
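The rewritten k-means|| step keeps each point, per run r, with probability proportional to its current cost, scaled so that roughly 2k points survive per step. A self-contained sketch of that probability with illustrative numbers:

{{{
import scala.util.Random

val k = 10
val pointCost = 4.0      // squared distance of the point to its closest center in run r
val sumCosts = 8000.0    // total cost of run r over the whole data set
val prob = 2.0 * pointCost * k / sumCosts  // 0.01; over all points this yields ~2k picks in expectation
val keep = Random.nextDouble() < prob      // mirrors `rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)`
}}}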
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 3414daccd7ca4..34e0392f1b21a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -257,80 +257,58 @@ private[spark] object BLAS extends Serializable with Logging {
/**
* C := alpha * A * B + beta * C
- * @param transA whether to use the transpose of matrix A (true), or A itself (false).
- * @param transB whether to use the transpose of matrix B (true), or B itself (false).
* @param alpha a scalar to scale the multiplication A * B.
* @param A the matrix A that will be left multiplied to B. Size of m x k.
* @param B the matrix B that will be left multiplied by A. Size of k x n.
* @param beta a scalar that can be used to scale matrix C.
- * @param C the resulting matrix C. Size of m x n.
+ * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false.
*/
def gemm(
- transA: Boolean,
- transB: Boolean,
alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
+ require(!C.isTransposed,
+ "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
if (alpha == 0.0) {
logDebug("gemm: alpha is equal to 0. Returning C.")
} else {
A match {
case sparse: SparseMatrix =>
- gemm(transA, transB, alpha, sparse, B, beta, C)
+ gemm(alpha, sparse, B, beta, C)
case dense: DenseMatrix =>
- gemm(transA, transB, alpha, dense, B, beta, C)
+ gemm(alpha, dense, B, beta, C)
case _ =>
throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
}
}
}
- /**
- * C := alpha * A * B + beta * C
- *
- * @param alpha a scalar to scale the multiplication A * B.
- * @param A the matrix A that will be left multiplied to B. Size of m x k.
- * @param B the matrix B that will be left multiplied by A. Size of k x n.
- * @param beta a scalar that can be used to scale matrix C.
- * @param C the resulting matrix C. Size of m x n.
- */
- def gemm(
- alpha: Double,
- A: Matrix,
- B: DenseMatrix,
- beta: Double,
- C: DenseMatrix): Unit = {
- gemm(false, false, alpha, A, B, beta, C)
- }
-
/**
* C := alpha * A * B + beta * C
* For `DenseMatrix` A.
*/
private def gemm(
- transA: Boolean,
- transB: Boolean,
alpha: Double,
A: DenseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
- val mA: Int = if (!transA) A.numRows else A.numCols
- val nB: Int = if (!transB) B.numCols else B.numRows
- val kA: Int = if (!transA) A.numCols else A.numRows
- val kB: Int = if (!transB) B.numRows else B.numCols
- val tAstr = if (!transA) "N" else "T"
- val tBstr = if (!transB) "N" else "T"
-
- require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
- require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
- require(nB == C.numCols,
- s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")
-
- nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows,
- beta, C.values, C.numRows)
+ val tAstr = if (A.isTransposed) "T" else "N"
+ val tBstr = if (B.isTransposed) "T" else "N"
+ val lda = if (!A.isTransposed) A.numRows else A.numCols
+ val ldb = if (!B.isTransposed) B.numRows else B.numCols
+
+ require(A.numCols == B.numRows,
+ s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}")
+ require(A.numRows == C.numRows,
+ s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}")
+ require(B.numCols == C.numCols,
+ s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}")
+
+ nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda,
+ B.values, ldb, beta, C.values, C.numRows)
}
/**
@@ -338,17 +316,15 @@ private[spark] object BLAS extends Serializable with Logging {
* For `SparseMatrix` A.
*/
private def gemm(
- transA: Boolean,
- transB: Boolean,
alpha: Double,
A: SparseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
- val mA: Int = if (!transA) A.numRows else A.numCols
- val nB: Int = if (!transB) B.numCols else B.numRows
- val kA: Int = if (!transA) A.numCols else A.numRows
- val kB: Int = if (!transB) B.numRows else B.numCols
+ val mA: Int = A.numRows
+ val nB: Int = B.numCols
+ val kA: Int = A.numCols
+ val kB: Int = B.numRows
require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
@@ -358,23 +334,23 @@ private[spark] object BLAS extends Serializable with Logging {
val Avals = A.values
val Bvals = B.values
val Cvals = C.values
- val Arows = if (!transA) A.rowIndices else A.colPtrs
- val Acols = if (!transA) A.colPtrs else A.rowIndices
+ val ArowIndices = A.rowIndices
+ val AcolPtrs = A.colPtrs
// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
- if (transA){
+ if (A.isTransposed) {
var colCounterForB = 0
- if (!transB) { // Expensive to put the check inside the loop
+ if (!B.isTransposed) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var rowCounterForA = 0
val Cstart = colCounterForB * mA
val Bstart = colCounterForB * kA
while (rowCounterForA < mA) {
- var i = Arows(rowCounterForA)
- val indEnd = Arows(rowCounterForA + 1)
+ var i = AcolPtrs(rowCounterForA)
+ val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
- sum += Avals(i) * Bvals(Bstart + Acols(i))
+ sum += Avals(i) * Bvals(Bstart + ArowIndices(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
@@ -385,19 +361,19 @@ private[spark] object BLAS extends Serializable with Logging {
}
} else {
while (colCounterForB < nB) {
- var rowCounter = 0
+ var rowCounterForA = 0
val Cstart = colCounterForB * mA
- while (rowCounter < mA) {
- var i = Arows(rowCounter)
- val indEnd = Arows(rowCounter + 1)
+ while (rowCounterForA < mA) {
+ var i = AcolPtrs(rowCounterForA)
+ val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
- sum += Avals(i) * B(colCounterForB, Acols(i))
+ sum += Avals(i) * B(ArowIndices(i), colCounterForB)
i += 1
}
- val Cindex = Cstart + rowCounter
+ val Cindex = Cstart + rowCounterForA
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
- rowCounter += 1
+ rowCounterForA += 1
}
colCounterForB += 1
}
@@ -410,17 +386,17 @@ private[spark] object BLAS extends Serializable with Logging {
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
// B, and added to C.
var colCounterForB = 0 // the column to be updated in C
- if (!transB) { // Expensive to put the check inside the loop
+ if (!B.isTransposed) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Bstart = colCounterForB * kB
val Cstart = colCounterForB * mA
while (colCounterForA < kA) {
- var i = Acols(colCounterForA)
- val indEnd = Acols(colCounterForA + 1)
+ var i = AcolPtrs(colCounterForA)
+ val indEnd = AcolPtrs(colCounterForA + 1)
val Bval = Bvals(Bstart + colCounterForA) * alpha
while (i < indEnd) {
- Cvals(Cstart + Arows(i)) += Avals(i) * Bval
+ Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
@@ -432,11 +408,11 @@ private[spark] object BLAS extends Serializable with Logging {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
while (colCounterForA < kA) {
- var i = Acols(colCounterForA)
- val indEnd = Acols(colCounterForA + 1)
- val Bval = B(colCounterForB, colCounterForA) * alpha
+ var i = AcolPtrs(colCounterForA)
+ val indEnd = AcolPtrs(colCounterForA + 1)
+ val Bval = B(colCounterForA, colCounterForB) * alpha
while (i < indEnd) {
- Cvals(Cstart + Arows(i)) += Avals(i) * Bval
+ Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
@@ -449,7 +425,6 @@ private[spark] object BLAS extends Serializable with Logging {
/**
* y := alpha * A * x + beta * y
- * @param trans whether to use the transpose of matrix A (true), or A itself (false).
* @param alpha a scalar to scale the multiplication A * x.
* @param A the matrix A that will be left multiplied to x. Size of m x n.
* @param x the vector x that will be left multiplied by A. Size of n x 1.
@@ -457,65 +432,43 @@ private[spark] object BLAS extends Serializable with Logging {
* @param y the resulting vector y. Size of m x 1.
*/
def gemv(
- trans: Boolean,
alpha: Double,
A: Matrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
-
- val mA: Int = if (!trans) A.numRows else A.numCols
- val nx: Int = x.size
- val nA: Int = if (!trans) A.numCols else A.numRows
-
- require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx")
- require(mA == y.size,
- s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}")
+ require(A.numCols == x.size,
+ s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
+ require(A.numRows == y.size,
+ s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}")
if (alpha == 0.0) {
logDebug("gemv: alpha is equal to 0. Returning y.")
} else {
A match {
case sparse: SparseMatrix =>
- gemv(trans, alpha, sparse, x, beta, y)
+ gemv(alpha, sparse, x, beta, y)
case dense: DenseMatrix =>
- gemv(trans, alpha, dense, x, beta, y)
+ gemv(alpha, dense, x, beta, y)
case _ =>
throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
}
}
}
- /**
- * y := alpha * A * x + beta * y
- *
- * @param alpha a scalar to scale the multiplication A * x.
- * @param A the matrix A that will be left multiplied to x. Size of m x n.
- * @param x the vector x that will be left multiplied by A. Size of n x 1.
- * @param beta a scalar that can be used to scale vector y.
- * @param y the resulting vector y. Size of m x 1.
- */
- def gemv(
- alpha: Double,
- A: Matrix,
- x: DenseVector,
- beta: Double,
- y: DenseVector): Unit = {
- gemv(false, alpha, A, x, beta, y)
- }
-
/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A.
*/
private def gemv(
- trans: Boolean,
alpha: Double,
A: DenseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
- val tStrA = if (!trans) "N" else "T"
- nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta,
+ val tStrA = if (A.isTransposed) "T" else "N"
+ val mA = if (!A.isTransposed) A.numRows else A.numCols
+ val nA = if (!A.isTransposed) A.numCols else A.numRows
+ nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
y.values, 1)
}
@@ -524,24 +477,21 @@ private[spark] object BLAS extends Serializable with Logging {
* For `SparseMatrix` A.
*/
private def gemv(
- trans: Boolean,
alpha: Double,
A: SparseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
-
val xValues = x.values
val yValues = y.values
-
- val mA: Int = if (!trans) A.numRows else A.numCols
- val nA: Int = if (!trans) A.numCols else A.numRows
+ val mA: Int = A.numRows
+ val nA: Int = A.numCols
val Avals = A.values
- val Arows = if (!trans) A.rowIndices else A.colPtrs
- val Acols = if (!trans) A.colPtrs else A.rowIndices
+ val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
+ val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
- if (trans) {
+ if (A.isTransposed) {
var rowCounter = 0
while (rowCounter < mA) {
var i = Arows(rowCounter)
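With the transA/transB flags removed, a transposed product is expressed through the matrix itself. A small sketch using the public multiply helper, which delegates to the gemm above; the values are arbitrary.

{{{
import org.apache.spark.mllib.linalg.DenseMatrix

val A = new DenseMatrix(3, 2, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))  // 3 x 2, column major
val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 1.0, 0.0))  // 3 x 2
// Previously this required gemm(transA = true, ...); now the transposed view carries that information.
val C = A.transpose.multiply(B)                                     // 2 x 2 result, A^T * B
}}}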
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 5a7281ec6dc3c..ad7e86827b368 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -34,8 +34,17 @@ sealed trait Matrix extends Serializable {
/** Number of columns. */
def numCols: Int
+ /** Flag that keeps track of whether the matrix is transposed or not. False by default. */
+ val isTransposed: Boolean = false
+
/** Converts to a dense array in column major. */
- def toArray: Array[Double]
+ def toArray: Array[Double] = {
+ val newArray = new Array[Double](numRows * numCols)
+ foreachActive { (i, j, v) =>
+ newArray(j * numRows + i) = v
+ }
+ newArray
+ }
/** Converts to a breeze matrix. */
private[mllib] def toBreeze: BM[Double]
@@ -52,10 +61,13 @@ sealed trait Matrix extends Serializable {
/** Get a deep copy of the matrix. */
def copy: Matrix
+ /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */
+ def transpose: Matrix
+
/** Convenience method for `Matrix`-`DenseMatrix` multiplication. */
def multiply(y: DenseMatrix): DenseMatrix = {
- val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix]
- BLAS.gemm(false, false, 1.0, this, y, 0.0, C)
+ val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols)
+ BLAS.gemm(1.0, this, y, 0.0, C)
C
}
@@ -66,20 +78,6 @@ sealed trait Matrix extends Serializable {
output
}
- /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */
- private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = {
- val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix]
- BLAS.gemm(true, false, 1.0, this, y, 0.0, C)
- C
- }
-
- /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */
- private[mllib] def transposeMultiply(y: DenseVector): DenseVector = {
- val output = new DenseVector(new Array[Double](numCols))
- BLAS.gemv(true, 1.0, this, y, 0.0, output)
- output
- }
-
/** A human readable representation of the matrix */
override def toString: String = toBreeze.toString()
@@ -92,6 +90,16 @@ sealed trait Matrix extends Serializable {
* backing array. For example, an operation such as addition or subtraction will only be
* performed on the non-zero values in a `SparseMatrix`. */
private[mllib] def update(f: Double => Double): Matrix
+
+ /**
+ * Applies a function `f` to all the active elements of dense and sparse matrices. The ordering
+ * of the elements is not defined.
+ *
+ * @param f the function takes three parameters where the first two parameters are the row
+ * and column indices respectively with the type `Int`, and the final parameter is the
+ * corresponding value in the matrix with type `Double`.
+ */
+ private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
}
/**
@@ -108,13 +116,35 @@ sealed trait Matrix extends Serializable {
* @param numRows number of rows
* @param numCols number of columns
* @param values matrix entries in column major
+ * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
+ * row major.
*/
-class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix {
+class DenseMatrix(
+ val numRows: Int,
+ val numCols: Int,
+ val values: Array[Double],
+ override val isTransposed: Boolean) extends Matrix {
require(values.length == numRows * numCols, "The number of values supplied doesn't match the " +
s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}")
- override def toArray: Array[Double] = values
+ /**
+ * Column-major dense matrix.
+ * The entry values are stored in a single array of doubles with columns listed in sequence.
+ * For example, the following matrix
+ * {{{
+ * 1.0 2.0
+ * 3.0 4.0
+ * 5.0 6.0
+ * }}}
+ * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`.
+ *
+ * @param numRows number of rows
+ * @param numCols number of columns
+ * @param values matrix entries in column major
+ */
+ def this(numRows: Int, numCols: Int, values: Array[Double]) =
+ this(numRows, numCols, values, false)
override def equals(o: Any) = o match {
case m: DenseMatrix =>
@@ -122,13 +152,22 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
case _ => false
}
- private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values)
+ private[mllib] def toBreeze: BM[Double] = {
+ if (!isTransposed) {
+ new BDM[Double](numRows, numCols, values)
+ } else {
+ val breezeMatrix = new BDM[Double](numCols, numRows, values)
+ breezeMatrix.t
+ }
+ }
private[mllib] def apply(i: Int): Double = values(i)
private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j))
- private[mllib] def index(i: Int, j: Int): Int = i + numRows * j
+ private[mllib] def index(i: Int, j: Int): Int = {
+ if (!isTransposed) i + numRows * j else j + numCols * i
+ }
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
values(index(i, j)) = v
@@ -148,7 +187,38 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
this
}
- /** Generate a `SparseMatrix` from the given `DenseMatrix`. */
+ override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed)
+
+ private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = {
+ if (!isTransposed) {
+ // outer loop over columns
+ var j = 0
+ while (j < numCols) {
+ var i = 0
+ val indStart = j * numRows
+ while (i < numRows) {
+ f(i, j, values(indStart + i))
+ i += 1
+ }
+ j += 1
+ }
+ } else {
+ // outer loop over rows
+ var i = 0
+ while (i < numRows) {
+ var j = 0
+ val indStart = i * numCols
+ while (j < numCols) {
+ f(i, j, values(indStart + j))
+ j += 1
+ }
+ i += 1
+ }
+ }
+ }
+
+ /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed
+ * set to false. */
def toSparse(): SparseMatrix = {
val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble
val colPtrs: Array[Int] = new Array[Int](numCols + 1)
@@ -157,9 +227,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
var j = 0
while (j < numCols) {
var i = 0
- val indStart = j * numRows
while (i < numRows) {
- val v = values(indStart + i)
+ val v = values(index(i, j))
if (v != 0.0) {
rowIndices += i
spVals += v
@@ -271,49 +340,73 @@ object DenseMatrix {
* @param rowIndices the row index of the entry. They must be in strictly increasing order for each
* column
* @param values non-zero matrix entries in column major
+ * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered
+ * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
+ * `rowIndices` behaves as colIndices, and `values` are stored in row major.
*/
class SparseMatrix(
val numRows: Int,
val numCols: Int,
val colPtrs: Array[Int],
val rowIndices: Array[Int],
- val values: Array[Double]) extends Matrix {
+ val values: Array[Double],
+ override val isTransposed: Boolean) extends Matrix {
require(values.length == rowIndices.length, "The number of row indices and values don't match! " +
s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}")
- require(colPtrs.length == numCols + 1, "The length of the column indices should be the " +
- s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " +
- s"numCols: $numCols")
+ // The second disjunct covers the case when the matrix is transposed (colPtrs then acts as rowPtrs).
+ require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " +
+ "column pointers should be the number of columns + 1 (or the number of rows + 1 if the " +
+ s"matrix is transposed). Currently, colPtrs.length: ${colPtrs.length}, numCols: $numCols")
require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " +
s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}")
- override def toArray: Array[Double] = {
- val arr = new Array[Double](numRows * numCols)
- var j = 0
- while (j < numCols) {
- var i = colPtrs(j)
- val indEnd = colPtrs(j + 1)
- val offset = j * numRows
- while (i < indEnd) {
- val rowIndex = rowIndices(i)
- arr(offset + rowIndex) = values(i)
- i += 1
- }
- j += 1
- }
- arr
+ /**
+ * Column-major sparse matrix.
+ * The entry values are stored in Compressed Sparse Column (CSC) format.
+ * For example, the following matrix
+ * {{{
+ * 1.0 0.0 4.0
+ * 0.0 3.0 5.0
+ * 2.0 0.0 6.0
+ * }}}
+ * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`,
+ * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`.
+ *
+ * @param numRows number of rows
+ * @param numCols number of columns
+ * @param colPtrs the index corresponding to the start of a new column
+ * @param rowIndices the row indices of the entries. They must be in strictly increasing
+ * order for each column
+ * @param values non-zero matrix entries in column major
+ */
+ def this(
+ numRows: Int,
+ numCols: Int,
+ colPtrs: Array[Int],
+ rowIndices: Array[Int],
+ values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false)
+
+ private[mllib] def toBreeze: BM[Double] = {
+ if (!isTransposed) {
+ new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
+ } else {
+ val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices)
+ breezeMatrix.t
+ }
}
- private[mllib] def toBreeze: BM[Double] =
- new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
-
private[mllib] def apply(i: Int, j: Int): Double = {
val ind = index(i, j)
if (ind < 0) 0.0 else values(ind)
}
private[mllib] def index(i: Int, j: Int): Int = {
- Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i)
+ if (!isTransposed) {
+ Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i)
+ } else {
+ Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j)
+ }
}
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
@@ -322,7 +415,7 @@ class SparseMatrix(
throw new NoSuchElementException("The given row and column indices correspond to a zero " +
"value. Only non-zero elements in Sparse Matrices can be updated.")
} else {
- values(index(i, j)) = v
+ values(ind) = v
}
}
@@ -341,7 +434,38 @@ class SparseMatrix(
this
}
- /** Generate a `DenseMatrix` from the given `SparseMatrix`. */
+ override def transpose: Matrix =
+ new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed)
+
+ private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = {
+ if (!isTransposed) {
+ var j = 0
+ while (j < numCols) {
+ var idx = colPtrs(j)
+ val idxEnd = colPtrs(j + 1)
+ while (idx < idxEnd) {
+ f(rowIndices(idx), j, values(idx))
+ idx += 1
+ }
+ j += 1
+ }
+ } else {
+ var i = 0
+ while (i < numRows) {
+ var idx = colPtrs(i)
+ val idxEnd = colPtrs(i + 1)
+ while (idx < idxEnd) {
+ val j = rowIndices(idx)
+ f(i, j, values(idx))
+ idx += 1
+ }
+ i += 1
+ }
+ }
+ }
+
+ /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed
+ * set to false. */
def toDense(): DenseMatrix = {
new DenseMatrix(numRows, numCols, toArray)
}
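A brief sketch of the new transpose contract with arbitrary values: the returned matrix is a view over the same array, and nothing is copied until something like toArray materializes it.

{{{
import org.apache.spark.mllib.linalg.DenseMatrix

val m = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))  // 2 x 3, column major
val t = m.transpose               // 3 x 2 view; shares `values`, only isTransposed flips
// t.numRows == 3, t.numCols == 2
val materialized = t.toArray      // column-major copy of the 3 x 2 transpose: [1, 3, 5, 2, 4, 6]
}}}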
@@ -557,10 +681,9 @@ object Matrices {
private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = {
breeze match {
case dm: BDM[Double] =>
- require(dm.majorStride == dm.rows,
- "Do not support stride size different from the number of rows.")
- new DenseMatrix(dm.rows, dm.cols, dm.data)
+ new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose)
case sm: BSM[Double] =>
+ // There is no isTranspose flag for sparse matrices in Breeze
new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data)
case _ =>
throw new UnsupportedOperationException(
@@ -679,46 +802,28 @@ object Matrices {
new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray))
} else {
var startCol = 0
- val entries: Array[(Int, Int, Double)] = matrices.flatMap {
- case spMat: SparseMatrix =>
- var j = 0
- val colPtrs = spMat.colPtrs
- val rowIndices = spMat.rowIndices
- val values = spMat.values
- val data = new Array[(Int, Int, Double)](values.length)
- val nCols = spMat.numCols
- while (j < nCols) {
- var idx = colPtrs(j)
- while (idx < colPtrs(j + 1)) {
- val i = rowIndices(idx)
- val v = values(idx)
- data(idx) = (i, j + startCol, v)
- idx += 1
+ val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat =>
+ val nCols = mat.numCols
+ mat match {
+ case spMat: SparseMatrix =>
+ val data = new Array[(Int, Int, Double)](spMat.values.length)
+ var cnt = 0
+ spMat.foreachActive { (i, j, v) =>
+ data(cnt) = (i, j + startCol, v)
+ cnt += 1
}
- j += 1
- }
- startCol += nCols
- data
- case dnMat: DenseMatrix =>
- val data = new ArrayBuffer[(Int, Int, Double)]()
- var j = 0
- val nCols = dnMat.numCols
- val nRows = dnMat.numRows
- val values = dnMat.values
- while (j < nCols) {
- var i = 0
- val indStart = j * nRows
- while (i < nRows) {
- val v = values(indStart + i)
+ startCol += nCols
+ data
+ case dnMat: DenseMatrix =>
+ val data = new ArrayBuffer[(Int, Int, Double)]()
+ dnMat.foreachActive { (i, j, v) =>
if (v != 0.0) {
data.append((i, j + startCol, v))
}
- i += 1
}
- j += 1
- }
- startCol += nCols
- data
+ startCol += nCols
+ data
+ }
}
SparseMatrix.fromCOO(numRows, numCols, entries)
}
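A small sketch (not part of the patch) of the mixed dense/sparse horzcat path above; with at least one sparse input, the result is assembled through SparseMatrix.fromCOO:

{{{
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices}

val left  = new DenseMatrix(2, 2, Array(1.0, 2.0, 3.0, 4.0))
val right = Matrices.speye(2)                 // 2x2 sparse identity
val h = Matrices.horzcat(Array(left, right))  // 2x4, built via the sparse path

assert(h.numRows == 2 && h.numCols == 4)
assert(h.toArray.toSeq == Seq(1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0))
}}}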
@@ -744,14 +849,12 @@ object Matrices {
require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " +
"don't match!")
mat match {
- case sparse: SparseMatrix =>
- hasSparse = true
- case dense: DenseMatrix =>
+ case sparse: SparseMatrix => hasSparse = true
+ case dense: DenseMatrix => // empty on purpose
case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " +
s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}")
}
numRows += mat.numRows
-
}
if (!hasSparse) {
val allValues = new Array[Double](numRows * numCols)
@@ -759,61 +862,37 @@ object Matrices {
matrices.foreach { mat =>
var j = 0
val nRows = mat.numRows
- val values = mat.toArray
- while (j < numCols) {
- var i = 0
+ mat.foreachActive { (i, j, v) =>
val indStart = j * numRows + startRow
- val subMatStart = j * nRows
- while (i < nRows) {
- allValues(indStart + i) = values(subMatStart + i)
- i += 1
- }
- j += 1
+ allValues(indStart + i) = v
}
startRow += nRows
}
new DenseMatrix(numRows, numCols, allValues)
} else {
var startRow = 0
- val entries: Array[(Int, Int, Double)] = matrices.flatMap {
- case spMat: SparseMatrix =>
- var j = 0
- val colPtrs = spMat.colPtrs
- val rowIndices = spMat.rowIndices
- val values = spMat.values
- val data = new Array[(Int, Int, Double)](values.length)
- while (j < numCols) {
- var idx = colPtrs(j)
- while (idx < colPtrs(j + 1)) {
- val i = rowIndices(idx)
- val v = values(idx)
- data(idx) = (i + startRow, j, v)
- idx += 1
+ val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat =>
+ val nRows = mat.numRows
+ mat match {
+ case spMat: SparseMatrix =>
+ val data = new Array[(Int, Int, Double)](spMat.values.length)
+ var cnt = 0
+ spMat.foreachActive { (i, j, v) =>
+ data(cnt) = (i + startRow, j, v)
+ cnt += 1
}
- j += 1
- }
- startRow += spMat.numRows
- data
- case dnMat: DenseMatrix =>
- val data = new ArrayBuffer[(Int, Int, Double)]()
- var j = 0
- val nCols = dnMat.numCols
- val nRows = dnMat.numRows
- val values = dnMat.values
- while (j < nCols) {
- var i = 0
- val indStart = j * nRows
- while (i < nRows) {
- val v = values(indStart + i)
+ startRow += nRows
+ data
+ case dnMat: DenseMatrix =>
+ val data = new ArrayBuffer[(Int, Int, Double)]()
+ dnMat.foreachActive { (i, j, v) =>
if (v != 0.0) {
data.append((i + startRow, j, v))
}
- i += 1
}
- j += 1
- }
- startRow += nRows
- data
+ startRow += nRows
+ data
+ }
}
SparseMatrix.fromCOO(numRows, numCols, entries)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 7ee0224ad4662..2834ea75ceb8f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -333,7 +333,7 @@ object Vectors {
math.pow(sum, 1.0 / p)
}
}
-
+
/**
* Returns the squared distance between two Vectors.
* @param v1 first Vector.
@@ -341,8 +341,9 @@ object Vectors {
* @return squared distance between two Vectors.
*/
def sqdist(v1: Vector, v2: Vector): Double = {
+ require(v1.size == v2.size, "vector dimension mismatch")
var squaredDistance = 0.0
- (v1, v2) match {
+ (v1, v2) match {
case (v1: SparseVector, v2: SparseVector) =>
val v1Values = v1.values
val v1Indices = v1.indices
@@ -350,12 +351,12 @@ object Vectors {
val v2Indices = v2.indices
val nnzv1 = v1Indices.size
val nnzv2 = v2Indices.size
-
+
var kv1 = 0
var kv2 = 0
while (kv1 < nnzv1 || kv2 < nnzv2) {
var score = 0.0
-
+
if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) {
score = v1Values(kv1)
kv1 += 1
@@ -370,18 +371,23 @@ object Vectors {
squaredDistance += score * score
}
- case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 =>
+ case (v1: SparseVector, v2: DenseVector) =>
squaredDistance = sqdist(v1, v2)
- case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 =>
+ case (v1: DenseVector, v2: SparseVector) =>
squaredDistance = sqdist(v2, v1)
- // When a SparseVector is approximately dense, we treat it as a DenseVector
- case (v1, v2) =>
- squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) =>
- val score = elems._1 - elems._2
- distance + score * score
+ case (DenseVector(vv1), DenseVector(vv2)) =>
+ var kv = 0
+ val sz = vv1.size
+ while (kv < sz) {
+ val score = vv1(kv) - vv2(kv)
+ squaredDistance += score * score
+ kv += 1
}
+ case _ =>
+ throw new IllegalArgumentException("Do not support vector type " + v1.getClass +
+ " and " + v2.getClass)
}
squaredDistance
}
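A minimal sketch (not part of the patch) of the behavior after this change: mixed sparse/dense pairs always take the specialized path, and mismatched sizes now fail fast:

{{{
import org.apache.spark.mllib.linalg.Vectors

val dv = Vectors.dense(1.0, 0.0, 3.0)
val sv = Vectors.sparse(3, Array(0, 2), Array(2.0, -1.0))

// (1 - 2)^2 + (0 - 0)^2 + (3 - (-1))^2 = 17
assert(math.abs(Vectors.sqdist(dv, sv) - 17.0) < 1e-12)

// Vectors.sqdist(dv, Vectors.dense(1.0, 2.0))  // now throws IllegalArgumentException
}}}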
@@ -397,7 +403,7 @@ object Vectors {
val nnzv1 = indices.size
val nnzv2 = v2.size
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
-
+
while (kv2 < nnzv2) {
var score = 0.0
if (kv2 != iv1) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index bee951a2e5e26..5f84677be238d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double)
*
* Essentially instead of finding the low-rank approximations to the rating matrix `R`,
* this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
- * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of
+ * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of
* indicated user
* preferences rather than explicit ratings given to items.
*/
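To make the description above concrete, a hypothetical sketch (helper names not from the patch) of how a raw rating r maps to the preference p and the confidence c = 1 + alpha * |r| used by the implicit-feedback model:

{{{
// Hypothetical helpers illustrating the preference/confidence mapping described above.
def preference(r: Double): Double = if (r > 0) 1.0 else 0.0
def confidence(r: Double, alpha: Double): Double = 1.0 + alpha * math.abs(r)

assert(preference(4.0) == 1.0 && preference(0.0) == 0.0 && preference(-2.0) == 0.0)
assert(confidence(-2.0, alpha = 0.5) == 2.0)
}}}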
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index e9304b5e5c650..482dd4b272d1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -140,6 +140,7 @@ private class RandomForest (
logDebug("maxBins = " + metadata.maxBins)
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+ logDebug("subsamplingRate = " + strategy.subsamplingRate)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
@@ -155,19 +156,12 @@ private class RandomForest (
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- val (subsample, withReplacement) = {
- // TODO: Have a stricter check for RF in the strategy
- val isRandomForest = numTrees > 1
- if (isRandomForest) {
- (1.0, true)
- } else {
- (strategy.subsamplingRate, false)
- }
- }
+    val withReplacement = numTrees > 1
val baggedInput
- = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
- .persist(StorageLevel.MEMORY_AND_DISK)
+ = BaggedPoint.convertToBaggedRDD(treeInput,
+ strategy.subsamplingRate, numTrees,
+ withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index cf51d041c65a9..ed8e6a796f8c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -68,6 +68,15 @@ case class BoostingStrategy(
@Experimental
object BoostingStrategy {
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported: "Classification" or "Regression"
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: String): BoostingStrategy = {
+ defaultParams(Algo.fromString(algo))
+ }
+
/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
@@ -75,15 +84,15 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: String): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStrategy(algo)
- treeStrategy.maxDepth = 3
+ def defaultParams(algo: Algo): BoostingStrategy = {
+    val treeStrategy = Strategy.defaultStrategy(algo)
+    treeStrategy.maxDepth = 3
algo match {
- case "Classification" =>
- treeStrategy.numClasses = 2
- new BoostingStrategy(treeStrategy, LogLoss)
- case "Regression" =>
- new BoostingStrategy(treeStrategy, SquaredError)
+ case Algo.Classification =>
+        treeStrategy.numClasses = 2
+        new BoostingStrategy(treeStrategy, LogLoss)
+      case Algo.Regression =>
+        new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
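A brief usage sketch (not part of the patch); the treeStrategy field name is assumed from the existing BoostingStrategy case class:

{{{
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

// The String overload parses the name and delegates to the Algo-typed overload,
// so both produce the same defaults.
val byName = BoostingStrategy.defaultParams("Classification")
val byAlgo = BoostingStrategy.defaultParams(Algo.Classification)
assert(byName.treeStrategy.numClasses == byAlgo.treeStrategy.numClasses)
}}}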
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d5cd89ab94e81..3308adb6752ff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -156,6 +156,9 @@ class Strategy (
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
require(maxMemoryInMB <= 10240,
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
+ require(subsamplingRate > 0 && subsamplingRate <= 1,
+      "DecisionTree Strategy requires 0 < subsamplingRate <= 1, but was given " +
+ s"$subsamplingRate")
}
/** Returns a shallow copy of this instance. */
@@ -173,11 +176,19 @@ object Strategy {
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
*/
- def defaultStrategy(algo: String): Strategy = algo match {
- case "Classification" =>
+ def defaultStrategy(algo: String): Strategy = {
+    defaultStrategy(Algo.fromString(algo))
+ }
+
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo Algo.Classification or Algo.Regression
+ */
+  def defaultStrategy(algo: Algo): Strategy = algo match {
+ case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
- case "Regression" =>
+ case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
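A matching sketch (not part of the patch) for the Strategy factories, assuming both overloads share the defaultStrategy name:

{{{
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}

// The String form is kept for backwards compatibility and simply parses
// the algo name before delegating to the Algo-typed form.
val s1 = Strategy.defaultStrategy("Regression")
val s2 = Strategy.defaultStrategy(Algo.Regression)
assert(s1.maxDepth == s2.maxDepth && s1.impurity == s2.impurity)
}}}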
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 0e02345aa3774..b7950e00786ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int)
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+      throw new IllegalArgumentException(s"EntropyAggregator given label $label " +
+        "but requires label to be non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc
val lbl = label.toInt
require(lbl < stats.length,
s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "Entropy does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 7c83cd48e16a0..c946db9c0d1c8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int)
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+      throw new IllegalArgumentException(s"GiniAggregator given label $label " +
+        "but requires label to be non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula
val lbl = label.toInt
require(lbl < stats.length,
s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "GiniImpurity does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 47f1f46c6c260..56a9dbdd58b64 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -26,7 +26,7 @@
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -37,7 +37,7 @@ public class JavaPipelineSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient SchemaRDD dataset;
+ private transient DataFrame dataset;
@Before
public void setUp() {
@@ -65,7 +65,7 @@ public void pipeline() {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 2eba83335bb58..f4ba23c44563e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -26,7 +26,7 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -34,7 +34,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient SchemaRDD dataset;
+ private transient DataFrame dataset;
@Before
public void setUp() {
@@ -55,7 +55,7 @@ public void logisticRegression() {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
}
@@ -67,7 +67,7 @@ public void logisticRegressionWithSetters() {
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
.registerTempTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index a9f1c4a2c3ca7..074b58c07df7a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,7 +30,7 @@
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -38,7 +38,7 @@ public class JavaCrossValidatorSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient SchemaRDD dataset;
+ private transient DataFrame dataset;
@Before
public void setUp() {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 4515084bc7ae9..2f175fb117941 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
class PipelineSuite extends FunSuite {
@@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite {
val estimator2 = mock[Estimator[MyModel]]
val model2 = mock[MyModel]
val transformer3 = mock[Transformer]
- val dataset0 = mock[SchemaRDD]
- val dataset1 = mock[SchemaRDD]
- val dataset2 = mock[SchemaRDD]
- val dataset3 = mock[SchemaRDD]
- val dataset4 = mock[SchemaRDD]
+ val dataset0 = mock[DataFrame]
+ val dataset1 = mock[DataFrame]
+ val dataset2 = mock[DataFrame]
+ val dataset3 = mock[DataFrame]
+ val dataset4 = mock[DataFrame]
when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
@@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite {
val estimator = mock[Estimator[MyModel]]
val pipeline = new Pipeline()
.setStages(Array(estimator, estimator))
- val dataset = mock[SchemaRDD]
+ val dataset = mock[DataFrame]
intercept[IllegalArgumentException] {
pipeline.fit(dataset)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index e8030fef55b1d..1912afce93b18 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -21,12 +21,12 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{SQLContext, DataFrame}
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
- @transient var dataset: SchemaRDD = _
+ @transient var dataset: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -36,34 +36,28 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
test("logistic regression") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression
val model = lr.fit(dataset)
model.transform(dataset)
- .select('label, 'prediction)
+ .select("label", "prediction")
.collect()
}
test("logistic regression with setters") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
val model = lr.fit(dataset)
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
- .select('label, 'score, 'prediction)
+ .select("label", "score", "prediction")
.collect()
}
test("logistic regression fit and transform with varargs") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression
val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
- .select('label, 'probability, 'prediction)
+ .select("label", "probability", "prediction")
.collect()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
new file mode 100644
index 0000000000000..58289acdbc095
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import java.util.Random
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.scalatest.FunSuite
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.recommendation.ALS._
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+
+class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
+
+ private var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ test("LocalIndexEncoder") {
+ val random = new Random
+ for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
+ val encoder = new LocalIndexEncoder(numBlocks)
+ val maxLocalIndex = Int.MaxValue / numBlocks
+ val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++
+ Seq((0, 0), (numBlocks - 1, maxLocalIndex))
+ tests.foreach { case (blockId, localIndex) =>
+ val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex."
+ val encoded = encoder.encode(blockId, localIndex)
+ assert(encoder.blockId(encoded) === blockId, err)
+ assert(encoder.localIndex(encoded) === localIndex, err)
+ }
+ }
+ }
+
+  test("normal equation construction with explicit feedback") {
+ val k = 2
+ val ne0 = new NormalEquation(k)
+ .add(Array(1.0f, 2.0f), 3.0f)
+ .add(Array(4.0f, 5.0f), 6.0f)
+ assert(ne0.k === k)
+ assert(ne0.triK === k * (k + 1) / 2)
+ assert(ne0.n === 2)
+ // NumPy code that computes the expected values:
+ // A = np.matrix("1 2; 4 5")
+ // b = np.matrix("3; 6")
+ // ata = A.transpose() * A
+ // atb = A.transpose() * b
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8)
+
+ val ne1 = new NormalEquation(2)
+ .add(Array(7.0f, 8.0f), 9.0f)
+ ne0.merge(ne1)
+ assert(ne0.n === 3)
+ // NumPy code that computes the expected values:
+ // A = np.matrix("1 2; 4 5; 7 8")
+ // b = np.matrix("3; 6; 9")
+ // ata = A.transpose() * A
+ // atb = A.transpose() * b
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8)
+
+ intercept[IllegalArgumentException] {
+ ne0.add(Array(1.0f), 2.0f)
+ }
+ intercept[IllegalArgumentException] {
+ ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f)
+ }
+ intercept[IllegalArgumentException] {
+ val ne2 = new NormalEquation(3)
+ ne0.merge(ne2)
+ }
+
+ ne0.reset()
+ assert(ne0.n === 0)
+ assert(ne0.ata.forall(_ == 0.0))
+ assert(ne0.atb.forall(_ == 0.0))
+ }
+
+ test("normal equation construction with implicit feedback") {
+ val k = 2
+ val alpha = 0.5
+ val ne0 = new NormalEquation(k)
+ .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha)
+ .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha)
+ .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha)
+ assert(ne0.k === k)
+ assert(ne0.triK === k * (k + 1) / 2)
+ assert(ne0.n === 0) // addImplicit doesn't increase the count.
+ // NumPy code that computes the expected values:
+ // alpha = 0.5
+ // A = np.matrix("-5 -4; -2 -1; 1 2")
+ // b = np.matrix("-3; 0; 3")
+ // b1 = b > 0
+ // c = 1.0 + alpha * np.abs(b)
+ // C = np.diag(c.A1)
+ // I = np.eye(3)
+ // ata = A.transpose() * (C - I) * A
+ // atb = A.transpose() * C * b1
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8)
+ }
+
+ test("CholeskySolver") {
+ val k = 2
+ val ne0 = new NormalEquation(k)
+ .add(Array(1.0f, 2.0f), 4.0f)
+ .add(Array(1.0f, 3.0f), 9.0f)
+ .add(Array(1.0f, 4.0f), 16.0f)
+ val ne1 = new NormalEquation(k)
+ .merge(ne0)
+
+ val chol = new CholeskySolver
+ val x0 = chol.solve(ne0, 0.0).map(_.toDouble)
+ // NumPy code that computes the expected solution:
+ // A = np.matrix("1 2; 1 3; 1 4")
+    // b = np.matrix("4; 9; 16")
+ // x0 = np.linalg.lstsq(A, b)[0]
+ assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6)
+
+ assert(ne0.n === 0)
+ assert(ne0.ata.forall(_ == 0.0))
+ assert(ne0.atb.forall(_ == 0.0))
+
+ val x1 = chol.solve(ne1, 0.5).map(_.toDouble)
+ // NumPy code that computes the expected solution, where lambda is scaled by n:
+ // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b)
+ assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6)
+ }
+
+ test("RatingBlockBuilder") {
+ val emptyBuilder = new RatingBlockBuilder()
+ assert(emptyBuilder.size === 0)
+ val emptyBlock = emptyBuilder.build()
+ assert(emptyBlock.srcIds.isEmpty)
+ assert(emptyBlock.dstIds.isEmpty)
+ assert(emptyBlock.ratings.isEmpty)
+
+ val builder0 = new RatingBlockBuilder()
+ .add(Rating(0, 1, 2.0f))
+ .add(Rating(3, 4, 5.0f))
+ assert(builder0.size === 2)
+ val builder1 = new RatingBlockBuilder()
+ .add(Rating(6, 7, 8.0f))
+ .merge(builder0.build())
+ assert(builder1.size === 3)
+ val block = builder1.build()
+ val ratings = Seq.tabulate(block.size) { i =>
+ (block.srcIds(i), block.dstIds(i), block.ratings(i))
+ }.toSet
+ assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f)))
+ }
+
+ test("UncompressedInBlock") {
+ val encoder = new LocalIndexEncoder(10)
+ val uncompressed = new UncompressedInBlockBuilder(encoder)
+ .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f))
+ .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f))
+ .build()
+ assert(uncompressed.size === 5)
+ val records = Seq.tabulate(uncompressed.size) { i =>
+ val dstEncodedIndex = uncompressed.dstEncodedIndices(i)
+ val dstBlockId = encoder.blockId(dstEncodedIndex)
+ val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
+ (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i))
+ }.toSet
+ val expected =
+ Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f))
+ assert(records === expected)
+
+ val compressed = uncompressed.compress()
+ assert(compressed.size === 5)
+ assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3))
+ assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5))
+ var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)]
+ var i = 0
+ while (i < compressed.srcIds.size) {
+ var j = compressed.dstPtrs(i)
+ while (j < compressed.dstPtrs(i + 1)) {
+ val dstEncodedIndex = compressed.dstEncodedIndices(j)
+ val dstBlockId = encoder.blockId(dstEncodedIndex)
+ val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
+ decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j)))
+ j += 1
+ }
+ i += 1
+ }
+ assert(decompressed.toSet === expected)
+ }
+
+ /**
+ * Generates an explicit feedback dataset for testing ALS.
+ * @param numUsers number of users
+ * @param numItems number of items
+ * @param rank rank
+ * @param noiseStd the standard deviation of additive Gaussian noise on training data
+ * @param seed random seed
+ * @return (training, test)
+ */
+ def genExplicitTestData(
+ numUsers: Int,
+ numItems: Int,
+ rank: Int,
+ noiseStd: Double = 0.0,
+ seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ val trainingFraction = 0.6
+ val testFraction = 0.3
+ val totalFraction = trainingFraction + testFraction
+ val random = new Random(seed)
+ val userFactors = genFactors(numUsers, rank, random)
+ val itemFactors = genFactors(numItems, rank, random)
+ val training = ArrayBuffer.empty[Rating]
+ val test = ArrayBuffer.empty[Rating]
+ for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
+ val x = random.nextDouble()
+ if (x < totalFraction) {
+ val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
+ if (x < trainingFraction) {
+ val noise = noiseStd * random.nextGaussian()
+ training += Rating(userId, itemId, rating + noise.toFloat)
+ } else {
+ test += Rating(userId, itemId, rating)
+ }
+ }
+ }
+ logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " +
+ s"and ${test.size} for test.")
+ (sc.parallelize(training, 2), sc.parallelize(test, 2))
+ }
+
+ /**
+ * Generates an implicit feedback dataset for testing ALS.
+ * @param numUsers number of users
+ * @param numItems number of items
+ * @param rank rank
+ * @param noiseStd the standard deviation of additive Gaussian noise on training data
+ * @param seed random seed
+ * @return (training, test)
+ */
+ def genImplicitTestData(
+ numUsers: Int,
+ numItems: Int,
+ rank: Int,
+ noiseStd: Double = 0.0,
+ seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ // The assumption of the implicit feedback model is that unobserved ratings are more likely to
+ // be negatives.
+ val positiveFraction = 0.8
+ val negativeFraction = 1.0 - positiveFraction
+ val trainingFraction = 0.6
+ val testFraction = 0.3
+ val totalFraction = trainingFraction + testFraction
+ val random = new Random(seed)
+ val userFactors = genFactors(numUsers, rank, random)
+ val itemFactors = genFactors(numItems, rank, random)
+ val training = ArrayBuffer.empty[Rating]
+ val test = ArrayBuffer.empty[Rating]
+ for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
+ val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
+ val threshold = if (rating > 0) positiveFraction else negativeFraction
+ val observed = random.nextDouble() < threshold
+ if (observed) {
+ val x = random.nextDouble()
+ if (x < totalFraction) {
+ if (x < trainingFraction) {
+ val noise = noiseStd * random.nextGaussian()
+ training += Rating(userId, itemId, rating + noise.toFloat)
+ } else {
+ test += Rating(userId, itemId, rating)
+ }
+ }
+ }
+ }
+ logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " +
+ s"and ${test.size} for test.")
+ (sc.parallelize(training, 2), sc.parallelize(test, 2))
+ }
+
+ /**
+ * Generates random user/item factors, with i.i.d. values drawn from U(a, b).
+ * @param size number of users/items
+ * @param rank number of features
+ * @param random random number generator
+ * @param a min value of the support (default: -1)
+ * @param b max value of the support (default: 1)
+ * @return a sequence of (ID, factors) pairs
+ */
+ private def genFactors(
+ size: Int,
+ rank: Int,
+ random: Random,
+ a: Float = -1.0f,
+ b: Float = 1.0f): Seq[(Int, Array[Float])] = {
+ require(size > 0 && size < Int.MaxValue / 3)
+ require(b > a)
+ val ids = mutable.Set.empty[Int]
+ while (ids.size < size) {
+ ids += random.nextInt()
+ }
+ val width = b - a
+ ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width)))
+ }
+
+ /**
+ * Test ALS using the given training/test splits and parameters.
+ * @param training training dataset
+ * @param test test dataset
+ * @param rank rank of the matrix factorization
+ * @param maxIter max number of iterations
+ * @param regParam regularization constant
+ * @param implicitPrefs whether to use implicit preference
+ * @param numUserBlocks number of user blocks
+ * @param numItemBlocks number of item blocks
+ * @param targetRMSE target test RMSE
+ */
+ def testALS(
+ training: RDD[Rating],
+ test: RDD[Rating],
+ rank: Int,
+ maxIter: Int,
+ regParam: Double,
+ implicitPrefs: Boolean = false,
+ numUserBlocks: Int = 2,
+ numItemBlocks: Int = 3,
+ targetRMSE: Double = 0.05): Unit = {
+ val sqlContext = this.sqlContext
+ import sqlContext.createSchemaRDD
+ val als = new ALS()
+ .setRank(rank)
+ .setRegParam(regParam)
+ .setImplicitPrefs(implicitPrefs)
+ .setNumUserBlocks(numUserBlocks)
+ .setNumItemBlocks(numItemBlocks)
+ val alpha = als.getAlpha
+ val model = als.fit(training)
+ val predictions = model.transform(test)
+ .select("rating", "prediction")
+ .map { case Row(rating: Float, prediction: Float) =>
+ (rating.toDouble, prediction.toDouble)
+ }
+ val rmse =
+ if (implicitPrefs) {
+ // TODO: Use a better (rank-based?) evaluation metric for implicit feedback.
+ // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE
+ // with the confidence scores as weights.
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
+ val confidence = 1.0 + alpha * math.abs(rating)
+ val rating01 = math.max(math.min(rating, 1.0), 0.0)
+ val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
+ val err = prediction01 - rating01
+ (confidence, confidence * err * err)
+ }.reduce { case ((c0, e0), (c1, e1)) =>
+ (c0 + c1, e0 + e1)
+ }
+ math.sqrt(weightedSumSq / totalWeight)
+ } else {
+ val mse = predictions.map { case (rating, prediction) =>
+ val err = rating - prediction
+ err * err
+ }.mean()
+ math.sqrt(mse)
+ }
+ logInfo(s"Test RMSE is $rmse.")
+ assert(rmse < targetRMSE)
+ }
+
+ test("exact rank-1 matrix") {
+ val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1)
+ testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001)
+ testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001)
+ }
+
+ test("approximate rank-1 matrix") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01)
+ testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
+ testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
+ }
+
+ test("approximate rank-2 matrix") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
+ testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
+ }
+
+ test("different block settings") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) {
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03,
+ numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks)
+ }
+ }
+
+ test("more blocks than ratings") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002,
+ numItemBlocks = 5, numUserBlocks = 5)
+ }
+
+ test("implicit feedback") {
+ val (training, test) =
+ genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true,
+ targetRMSE = 0.3)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 41cc13da4d5b1..74104fa7a681a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -23,11 +23,11 @@ import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{SQLContext, DataFrame}
class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
- @transient var dataset: SchemaRDD = _
+ @transient var dataset: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 771878e925ea7..b0b78acd6df16 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -169,16 +169,17 @@ class BLASSuite extends FunSuite {
}
test("gemm") {
-
val dA =
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0))
val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0))
+ val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0))
+ val BT = B.transpose
- assert(dA multiply B ~== expected absTol 1e-15)
- assert(sA multiply B ~== expected absTol 1e-15)
+ assert(dA.multiply(B) ~== expected absTol 1e-15)
+ assert(sA.multiply(B) ~== expected absTol 1e-15)
val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0))
val C2 = C1.copy
@@ -188,6 +189,10 @@ class BLASSuite extends FunSuite {
val C6 = C1.copy
val C7 = C1.copy
val C8 = C1.copy
+ val C9 = C1.copy
+ val C10 = C1.copy
+ val C11 = C1.copy
+ val C12 = C1.copy
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
@@ -202,26 +207,40 @@ class BLASSuite extends FunSuite {
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
- gemm(true, false, 1.0, dA, B, 2.0, C1)
+ gemm(1.0, dA.transpose, B, 2.0, C1)
}
}
- val dAT =
+ val dATman =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
- val sAT =
+ val sATman =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
- assert(dAT transposeMultiply B ~== expected absTol 1e-15)
- assert(sAT transposeMultiply B ~== expected absTol 1e-15)
-
- gemm(true, false, 1.0, dAT, B, 2.0, C5)
- gemm(true, false, 1.0, sAT, B, 2.0, C6)
- gemm(true, false, 2.0, dAT, B, 2.0, C7)
- gemm(true, false, 2.0, sAT, B, 2.0, C8)
+ val dATT = dATman.transpose
+ val sATT = sATman.transpose
+ val BTT = BTman.transpose.asInstanceOf[DenseMatrix]
+
+ assert(dATT.multiply(B) ~== expected absTol 1e-15)
+ assert(sATT.multiply(B) ~== expected absTol 1e-15)
+ assert(dATT.multiply(BTT) ~== expected absTol 1e-15)
+ assert(sATT.multiply(BTT) ~== expected absTol 1e-15)
+
+ gemm(1.0, dATT, BTT, 2.0, C5)
+ gemm(1.0, sATT, BTT, 2.0, C6)
+ gemm(2.0, dATT, BTT, 2.0, C7)
+ gemm(2.0, sATT, BTT, 2.0, C8)
+ gemm(1.0, dA, BTT, 2.0, C9)
+ gemm(1.0, sA, BTT, 2.0, C10)
+ gemm(2.0, dA, BTT, 2.0, C11)
+ gemm(2.0, sA, BTT, 2.0, C12)
assert(C5 ~== expected2 absTol 1e-15)
assert(C6 ~== expected2 absTol 1e-15)
assert(C7 ~== expected3 absTol 1e-15)
assert(C8 ~== expected3 absTol 1e-15)
+ assert(C9 ~== expected2 absTol 1e-15)
+ assert(C10 ~== expected2 absTol 1e-15)
+ assert(C11 ~== expected3 absTol 1e-15)
+ assert(C12 ~== expected3 absTol 1e-15)
}
test("gemv") {
@@ -233,17 +252,13 @@ class BLASSuite extends FunSuite {
val x = new DenseVector(Array(1.0, 2.0, 3.0))
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
- assert(dA multiply x ~== expected absTol 1e-15)
- assert(sA multiply x ~== expected absTol 1e-15)
+ assert(dA.multiply(x) ~== expected absTol 1e-15)
+ assert(sA.multiply(x) ~== expected absTol 1e-15)
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
val y3 = y1.copy
val y4 = y1.copy
- val y5 = y1.copy
- val y6 = y1.copy
- val y7 = y1.copy
- val y8 = y1.copy
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
@@ -257,25 +272,18 @@ class BLASSuite extends FunSuite {
assert(y4 ~== expected3 absTol 1e-15)
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
- gemv(true, 1.0, dA, x, 2.0, y1)
+ gemv(1.0, dA.transpose, x, 2.0, y1)
}
}
-
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
- assert(dAT transposeMultiply x ~== expected absTol 1e-15)
- assert(sAT transposeMultiply x ~== expected absTol 1e-15)
-
- gemv(true, 1.0, dAT, x, 2.0, y5)
- gemv(true, 1.0, sAT, x, 2.0, y6)
- gemv(true, 2.0, dAT, x, 2.0, y7)
- gemv(true, 2.0, sAT, x, 2.0, y8)
- assert(y5 ~== expected2 absTol 1e-15)
- assert(y6 ~== expected2 absTol 1e-15)
- assert(y7 ~== expected3 absTol 1e-15)
- assert(y8 ~== expected3 absTol 1e-15)
+ val dATT = dAT.transpose
+ val sATT = sAT.transpose
+
+ assert(dATT.multiply(x) ~== expected absTol 1e-15)
+ assert(sATT.multiply(x) ~== expected absTol 1e-15)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
index 73a6d3a27d868..2031032373971 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
@@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite {
assert(mat.numRows === breeze.rows)
assert(mat.numCols === breeze.cols)
assert(mat.values.eq(breeze.data), "should not copy data")
+ // transposed matrix
+ val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix]
+ assert(matTransposed.numRows === breeze.cols)
+ assert(matTransposed.numCols === breeze.rows)
+ assert(matTransposed.values.eq(breeze.data), "should not copy data")
}
test("sparse matrix to breeze") {
@@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite {
assert(mat.numRows === breeze.rows)
assert(mat.numCols === breeze.cols)
assert(mat.values.eq(breeze.data), "should not copy data")
+ val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix]
+ assert(matTransposed.numRows === breeze.cols)
+ assert(matTransposed.numCols === breeze.rows)
+ assert(!matTransposed.values.eq(breeze.data), "has to copy data")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index a35d0fe389fdd..b1ebfde0e5e57 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -22,6 +22,9 @@ import java.util.Random
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar._
+import scala.collection.mutable.{Map => MutableMap}
+
+import org.apache.spark.mllib.util.TestingUtils._
class MatricesSuite extends FunSuite {
test("dense matrix construction") {
@@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite {
assert(mat.numRows === m)
assert(mat.numCols === n)
assert(mat.values.eq(values), "should not copy data")
- assert(mat.toArray.eq(values), "toArray should not copy data")
}
test("dense matrix construction with wrong dimension") {
@@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite {
assert(deMat1.toArray === deMat2.toArray)
}
+ test("transpose") {
+ val dA =
+ new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
+ val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
+
+ val dAT = dA.transpose.asInstanceOf[DenseMatrix]
+ val sAT = sA.transpose.asInstanceOf[SparseMatrix]
+ val dATexpected =
+ new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
+ val sATexpected =
+ new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
+
+ assert(dAT.toBreeze === dATexpected.toBreeze)
+ assert(sAT.toBreeze === sATexpected.toBreeze)
+ assert(dA(1, 0) === dAT(0, 1))
+ assert(dA(2, 1) === dAT(1, 2))
+ assert(sA(1, 0) === sAT(0, 1))
+ assert(sA(2, 1) === sAT(1, 2))
+
+ assert(!dA.toArray.eq(dAT.toArray), "has to have a new array")
+ assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array")
+
+ assert(dAT.toSparse().toBreeze === sATexpected.toBreeze)
+ assert(sAT.toDense().toBreeze === dATexpected.toBreeze)
+ }
+
+ test("foreachActive") {
+ val m = 3
+ val n = 2
+ val values = Array(1.0, 2.0, 4.0, 5.0)
+ val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0)
+ val colPtrs = Array(0, 2, 4)
+ val rowIndices = Array(0, 1, 1, 2)
+
+ val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values)
+ val dn = new DenseMatrix(m, n, allValues)
+
+ val dnMap = MutableMap[(Int, Int), Double]()
+ dn.foreachActive { (i, j, value) =>
+ dnMap.put((i, j), value)
+ }
+ assert(dnMap.size === 6)
+ assert(dnMap(0, 0) === 1.0)
+ assert(dnMap(1, 0) === 2.0)
+ assert(dnMap(2, 0) === 0.0)
+ assert(dnMap(0, 1) === 0.0)
+ assert(dnMap(1, 1) === 4.0)
+ assert(dnMap(2, 1) === 5.0)
+
+ val spMap = MutableMap[(Int, Int), Double]()
+ sp.foreachActive { (i, j, value) =>
+ spMap.put((i, j), value)
+ }
+ assert(spMap.size === 4)
+ assert(spMap(0, 0) === 1.0)
+ assert(spMap(1, 0) === 2.0)
+ assert(spMap(1, 1) === 4.0)
+ assert(spMap(2, 1) === 5.0)
+ }
+
test("horzcat, vertcat, eye, speye") {
val m = 3
val n = 2
@@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite {
val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0)
val colPtrs = Array(0, 2, 4)
val rowIndices = Array(0, 1, 1, 2)
+ // transposed versions
+ val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0)
+ val colPtrsT = Array(0, 1, 3, 4)
+ val rowIndicesT = Array(0, 0, 1, 1)
val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values)
val deMat1 = new DenseMatrix(m, n, allValues)
+ val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values)
+ val deMat1T = new DenseMatrix(n, m, allValuesT)
+
+ // should equal spMat1 & deMat1 respectively
+ val spMat1TT = spMat1T.transpose
+ val deMat1TT = deMat1T.transpose
+
val deMat2 = Matrices.eye(3)
val spMat2 = Matrices.speye(3)
val deMat3 = Matrices.eye(2)
@@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite {
val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2))
val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2))
val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2))
-
val deHorz2 = Matrices.horzcat(Array[Matrix]())
assert(deHorz1.numRows === 3)
@@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite {
assert(deHorz2.numCols === 0)
assert(deHorz2.toArray.length === 0)
- assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix)
- assert(spHorz2.toBreeze === spHorz3.toBreeze)
+ assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15)
+ assert(spHorz2 ~== spHorz3 absTol 1e-15)
assert(spHorz(0, 0) === 1.0)
assert(spHorz(2, 1) === 5.0)
assert(spHorz(0, 2) === 1.0)
@@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite {
assert(deHorz1(2, 4) === 1.0)
assert(deHorz1(1, 4) === 0.0)
+ // containing transposed matrices
+ val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2))
+ val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2))
+ val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2))
+ val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2))
+
+ assert(deHorz1T ~== deHorz1 absTol 1e-15)
+ assert(spHorzT ~== spHorz absTol 1e-15)
+ assert(spHorz2T ~== spHorz2 absTol 1e-15)
+ assert(spHorz3T ~== spHorz3 absTol 1e-15)
+
intercept[IllegalArgumentException] {
Matrices.horzcat(Array(spMat1, spMat3))
}
@@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite {
assert(deVert2.numCols === 0)
assert(deVert2.toArray.length === 0)
- assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix)
- assert(spVert2.toBreeze === spVert3.toBreeze)
+ assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15)
+ assert(spVert2 ~== spVert3 absTol 1e-15)
assert(spVert(0, 0) === 1.0)
assert(spVert(2, 1) === 5.0)
assert(spVert(3, 0) === 1.0)
@@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite {
assert(deVert1(3, 1) === 0.0)
assert(deVert1(4, 1) === 1.0)
+ // containing transposed matrices
+ val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3))
+ val deVert1T = Matrices.vertcat(Array(deMat1TT, deMat3))
+ val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3))
+ val spVert3T = Matrices.vertcat(Array(deMat1TT, spMat3))
+
+ assert(deVert1T ~== deVert1 absTol 1e-15)
+ assert(spVertT ~== spVert absTol 1e-15)
+ assert(spVert2T ~== spVert2 absTol 1e-15)
+ assert(spVert3T ~== spVert3 absTol 1e-15)
+
intercept[IllegalArgumentException] {
Matrices.vertcat(Array(spMat1, spMat2))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index f3b7bfda788fa..e9fc37e000526 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
* @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
* @param implicitPrefs flag to test implicit feedback
- * @param bulkPredict flag to test bulk prediciton
+   * @param bulkPredict flag to test bulk prediction
* @param negativeWeights whether the generated data can contain negative values
* @param numUserBlocks number of user blocks to partition users into
* @param numProductBlocks number of product blocks to partition products into
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
new file mode 100644
index 0000000000000..92b498580af03
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
+ */
+class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+ test("Gini impurity does not support negative labels") {
+ val gini = new GiniAggregator(2)
+ intercept[IllegalArgumentException] {
+ gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+
+ test("Entropy does not support negative labels") {
+ val entropy = new EntropyAggregator(2)
+ intercept[IllegalArgumentException] {
+ entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index f7f0f20c6c125..55e963977b54f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
+
+  test("subsampling rate in RandomForest") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int],
+ useNodeIdCache = true)
+
+ val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+ featureSubsetStrategy = "auto", seed = 123)
+ strategy.subsamplingRate = 0.5
+ val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+ featureSubsetStrategy = "auto", seed = 123)
+ assert(rf1.toDebugString != rf2.toDebugString)
+ }
+
}
diff --git a/pom.xml b/pom.xml
index b993391b15042..05cb3797fc55b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -117,7 +117,7 @@
2.0.1
0.21.0
shaded-protobuf
- 1.7.5
+ 1.7.10
1.2.17
1.0.4
2.4.1
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 127973b658190..e750fed7448cd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -52,6 +52,20 @@ object MimaExcludes {
"org.apache.spark.mllib.linalg.Matrices.randn"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrices.rand")
+ ) ++ Seq(
+ // SPARK-5321
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.transpose"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." +
+ "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.isTransposed"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.foreachActive")
) ++ Seq(
// SPARK-3325
ProblemFilters.exclude[MissingMethodProblem](
@@ -81,7 +95,20 @@ object MimaExcludes {
) ++ Seq(
// SPARK-5166 Spark SQL API stabilization
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit")
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate")
) ++ Seq(
// SPARK-5270
ProblemFilters.exclude[MissingMethodProblem](
@@ -90,6 +117,10 @@ object MimaExcludes {
// SPARK-5297 Java FileStream do not work with custom key/values
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")
+ ) ++ Seq(
+ // SPARK-5315 Spark Streaming Java API returns Scala DStream
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow")
)
case v if v.startsWith("1.2") =>
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 64f6a3ca6bf4c..568e21f3803bf 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -229,6 +229,14 @@ def _ensure_initialized(cls, instance=None, gateway=None):
else:
SparkContext._active_spark_context = instance
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle SparkContext, which is always an error:
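+ # A hypothetical snippet that would hit this path (the closure captures `sc`, so
+ # pickling the task attempts to serialize the SparkContext itself):
+ #   sc.parallelize(range(3)).map(lambda x: sc.textFile("data.txt").count()).collect()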
+ raise Exception(
+ "It appears that you are attempting to reference SparkContext from a broadcast "
+ "variable, action, or transforamtion. SparkContext can only be used on the driver, "
+ "not in code that it run on workers. For more information, see SPARK-5063."
+ )
+
def __enter__(self):
"""
Enable 'with SparkContext(...) as sc: app(sc)' syntax.
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index a975dc19cb78e..a0a028446d5fd 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -111,10 +111,9 @@ def run(self):
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
- java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
+ # TODO(davies): move into sql
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4977400ac1c05..f4cfe4845dc20 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -141,6 +141,17 @@ def id(self):
def __repr__(self):
return self._jrdd.toString()
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle an RDD, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to broadcast an RDD or reference an RDD from an "
+ "action or transformation. RDD transformations and actions can only be invoked by the "
+ "driver, not inside of other transformations; for example, "
+ "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values "
+ "transformation and count action cannot be performed inside of the rdd1.map "
+ "transformation. For more information, see SPARK-5063."
+ )
+
@property
def context(self):
"""
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 1990323249cf6..7d7550c854b2f 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,15 +20,19 @@
- L{SQLContext}
Main entry point for SQL functionality.
- - L{SchemaRDD}
+ - L{DataFrame}
A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
- addition to normal RDD operations, SchemaRDDs also support SQL.
+ addition to normal RDD operations, DataFrames also support SQL.
+ - L{GroupedDataFrame}
+ Aggregation methods, returned by L{DataFrame.groupBy}.
+ - L{Column}
+ Column is a DataFrame with a single column.
- L{Row}
A Row of data returned by a Spark SQL query.
- L{HiveContext}
Main entry point for accessing data stored in Apache Hive.
"""
+import sys
import itertools
import decimal
import datetime
@@ -36,6 +40,9 @@
import warnings
import json
import re
+import random
+import os
+from tempfile import NamedTemporaryFile
from array import array
from operator import itemgetter
from itertools import imap
@@ -43,6 +50,7 @@
from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter
+from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
CloudPickleSerializer, UTF8Deserializer
@@ -54,7 +62,8 @@
"StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
- "SQLContext", "HiveContext", "SchemaRDD", "Row"]
+ "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
+ "SchemaRDD"]
class DataType(object):
@@ -1171,7 +1180,7 @@ def Dict(d):
class Row(tuple):
- """ Row in SchemaRDD """
+ """ Row in DataFrame """
__DATATYPE__ = dataType
__FIELDS__ = tuple(f.name for f in dataType.fields)
__slots__ = ()
@@ -1198,7 +1207,7 @@ class SQLContext(object):
"""Main entry point for Spark SQL functionality.
- A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as
+ A SQLContext can be used to create L{DataFrame}, register L{DataFrame} as
tables, execute SQL over tables, cache tables, and read parquet files.
"""
@@ -1209,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None):
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError:...
@@ -1225,12 +1234,12 @@ def __init__(self, sparkContext, sqlContext=None):
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> srdd = sqlCtx.inferSchema(allTypes)
- >>> srdd.registerTempTable("allTypes")
+ >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
- >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
+ >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
... x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -1309,23 +1318,23 @@ def inferSchema(self, rdd, samplingRatio=None):
... [Row(field1=1, field2="row1"),
... Row(field1=2, field2="row2"),
... Row(field1=3, field2="row3")])
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()[0]
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
Row(field1=1, field2=u'row1')
>>> NestedRow = Row("f1", "f2")
>>> nestedRdd1 = sc.parallelize([
... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> srdd = sqlCtx.inferSchema(nestedRdd1)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd1)
+ >>> df.collect()
[Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
>>> nestedRdd2 = sc.parallelize([
... NestedRow([[1, 2], [2, 3]], [1, 2]),
... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> srdd = sqlCtx.inferSchema(nestedRdd2)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd2)
+ >>> df.collect()
[Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
>>> from collections import namedtuple
@@ -1334,13 +1343,13 @@ def inferSchema(self, rdd, samplingRatio=None):
... [CustomRow(field1=1, field2="row1"),
... CustomRow(field1=2, field2="row2"),
... CustomRow(field1=3, field2="row3")])
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()[0]
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
Row(field1=1, field2=u'row1')
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
first = rdd.first()
if not first:
@@ -1384,10 +1393,10 @@ def applySchema(self, rdd, schema):
>>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
- >>> srdd = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT * from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.applySchema(rdd2, schema)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT * from table1")
+ >>> df2.collect()
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
>>> from datetime import date, datetime
@@ -1410,15 +1419,15 @@ def applySchema(self, rdd, schema):
... StructType([StructField("b", ShortType(), False)]), False),
... StructField("list", ArrayType(ByteType(), False), False),
... StructField("null", DoubleType(), True)])
- >>> srdd = sqlCtx.applySchema(rdd, schema)
- >>> results = srdd.map(
+ >>> df = sqlCtx.applySchema(rdd, schema)
+ >>> results = df.map(
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
... x.time, x.map["a"], x.struct.b, x.list, x.null))
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
- >>> srdd.registerTempTable("table2")
+ >>> df.registerTempTable("table2")
>>> sqlCtx.sql(
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
@@ -1431,13 +1440,13 @@ def applySchema(self, rdd, schema):
>>> abstract = "byte short float time map{} struct(b) list[]"
>>> schema = _parse_schema_abstract(abstract)
>>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> srdd = sqlCtx.applySchema(rdd, typedSchema)
- >>> srdd.collect()
+ >>> df = sqlCtx.applySchema(rdd, typedSchema)
+ >>> df.collect()
[Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
@@ -1457,8 +1466,8 @@ def applySchema(self, rdd, schema):
rdd = rdd.map(converter)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- return SchemaRDD(srdd, self)
+ df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ return DataFrame(df, self)
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -1466,34 +1475,34 @@ def registerRDDAsTable(self, rdd, tableName):
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
"""
- if (rdd.__class__ is SchemaRDD):
- srdd = rdd._jschema_rdd.baseSchemaRDD()
- self._ssql_ctx.registerRDDAsTable(srdd, tableName)
+ if (rdd.__class__ is DataFrame):
+ df = rdd._jdf
+ self._ssql_ctx.registerRDDAsTable(df, tableName)
else:
- raise ValueError("Can only register SchemaRDD as table")
+ raise ValueError("Can only register DataFrame as table")
def parquetFile(self, path):
- """Loads a Parquet file, returning the result as a L{SchemaRDD}.
+ """Loads a Parquet file, returning the result as a L{DataFrame}.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- jschema_rdd = self._ssql_ctx.parquetFile(path)
- return SchemaRDD(jschema_rdd, self)
+ jdf = self._ssql_ctx.parquetFile(path)
+ return DataFrame(jdf, self)
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
- L{SchemaRDD}.
+ L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
@@ -1508,23 +1517,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
>>> for json in jsonStrings:
... print>>ofn, json
>>> ofn.close()
- >>> srdd1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
@@ -1536,23 +1545,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
+ df = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
- return SchemaRDD(srdd, self)
+ df = self._ssql_ctx.jsonFile(path, scala_datatype)
+ return DataFrame(df, self)
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
- """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
+ """Loads an RDD storing one JSON object per string as a L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
@@ -1560,23 +1569,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
Otherwise, it samples the dataset with ratio `samplingRatio` to
determine the schema.
- >>> srdd1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
@@ -1588,12 +1597,12 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
>>> sqlCtx.jsonRDD(sc.parallelize(['{}',
@@ -1615,33 +1624,33 @@ def func(iterator):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
- return SchemaRDD(srdd, self)
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+ return DataFrame(df, self)
def sql(self, sqlQuery):
- """Return a L{SchemaRDD} representing the result of the given query.
+ """Return a L{DataFrame} representing the result of the given query.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+ return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
def table(self, tableName):
- """Returns the specified table as a L{SchemaRDD}.
+ """Returns the specified table as a L{DataFrame}.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.table("table1")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.table("table1")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- return SchemaRDD(self._ssql_ctx.table(tableName), self)
+ return DataFrame(self._ssql_ctx.table(tableName), self)
def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
@@ -1707,7 +1716,7 @@ def _create_row(fields, values):
class Row(tuple):
"""
- A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+ A row in L{DataFrame}. The fields in it can be accessed like attributes.
Row can be used to create a row object by using named arguments,
the fields will be sorted by names.
@@ -1799,111 +1808,119 @@ def inherit_doc(cls):
return cls
-@inherit_doc
-class SchemaRDD(RDD):
+class DataFrame(object):
- """An RDD of L{Row} objects that has an associated schema.
+ """A collection of rows that have the same columns.
- The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
- utilize the relational query api exposed by Spark SQL.
+ A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+ and can be created using various functions in :class:`SQLContext`::
- For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
- L{SchemaRDD} is not operated on directly, as it's underlying
- implementation is an RDD composed of Java objects. Instead it is
- converted to a PythonRDD in the JVM, on which Python operations can
- be done.
+ people = sqlContext.parquetFile("...")
- This class receives raw tuples from Java but assigns a class to it in
- all its data-collection methods (mapPartitionsWithIndex, collect, take,
- etc) so that PySpark sees them as Row objects with named fields.
+ Once created, it can be manipulated using the various domain-specific-language
+ (DSL) functions defined in: [[DataFrame]], [[Column]].
+
+ To select a column from the DataFrame, use attribute access::
+
+ ageCol = people.age
+
+ Note that the :class:`Column` type can also be manipulated
+ through its various functions::
+
+ # The following creates a new column that increases everybody's age by 10.
+ people.age + 10
+
+
+ A more concrete example::
+
+ # To create DataFrame using SQLContext
+ people = sqlContext.parquetFile("...")
+ department = sqlContext.parquetFile("...")
+
+ people.filter(people.age > 30).join(department, people.deptId == department.id) \
+ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
"""
- def __init__(self, jschema_rdd, sql_ctx):
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
self.sql_ctx = sql_ctx
- self._sc = sql_ctx._sc
- clsName = jschema_rdd.getClass().getName()
- assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD"
- self._jschema_rdd = jschema_rdd
- self._id = None
+ self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
- self.is_checkpointed = False
- self.ctx = self.sql_ctx._sc
- # the _jrdd is created by javaToPython(), serialized by pickle
- self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer())
@property
- def _jrdd(self):
- """Lazy evaluation of PythonRDD object.
+ def rdd(self):
+ """Return the content of the :class:`DataFrame` as an :class:`RDD`
+ of :class:`Row`s. """
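+ # A minimal usage sketch (assuming `df` is a DataFrame with a `field1` column):
+ #   df.rdd.map(lambda row: row.field1).collect()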
+ if not hasattr(self, '_lazy_rdd'):
+ jrdd = self._jdf.javaToPython()
+ rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+ schema = self.schema()
- Only done when a user calls methods defined by the
- L{pyspark.rdd.RDD} super class (map, filter, etc.).
- """
- if not hasattr(self, '_lazy_jrdd'):
- self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
- return self._lazy_jrdd
+ def applySchema(it):
+ cls = _create_cls(schema)
+ return itertools.imap(cls, it)
- def id(self):
- if self._id is None:
- self._id = self._jrdd.id()
- return self._id
+ self._lazy_rdd = rdd.mapPartitions(applySchema)
+
+ return self._lazy_rdd
def limit(self, num):
"""Limit the result count to the number specified.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.limit(2).collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.limit(2).collect()
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
- >>> srdd.limit(0).collect()
+ >>> df.limit(0).collect()
[]
"""
- rdd = self._jschema_rdd.baseSchemaRDD().limit(num)
- return SchemaRDD(rdd, self.sql_ctx)
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sql_ctx)
def toJSON(self, use_unicode=False):
- """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row.
+ """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
- >>> srdd1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql( "SELECT * from table1")
- >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql( "SELECT * from table1")
+ >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
True
- >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1")
- >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+ >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1")
+ >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
True
"""
- rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
+ rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a SchemaRDD using the L{SQLContext.parquetFile} method.
+ a DataFrame using the L{SQLContext.parquetFile} method.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd2.collect()) == sorted(srdd.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df2.collect()) == sorted(df.collect())
True
"""
- self._jschema_rdd.saveAsParquetFile(path)
+ self._jdf.saveAsParquetFile(path)
def registerTempTable(self, name):
"""Registers this RDD as a temporary table using the given name.
The lifetime of this temporary table is tied to the L{SQLContext}
- that was used to create this SchemaRDD.
+ that was used to create this DataFrame.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.registerTempTable("test")
- >>> srdd2 = sqlCtx.sql("select * from test")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.registerTempTable("test")
+ >>> df2 = sqlCtx.sql("select * from test")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- self._jschema_rdd.registerTempTable(name)
+ self._jdf.registerTempTable(name)
def registerAsTable(self, name):
"""DEPRECATED: use registerTempTable() instead"""
@@ -1911,62 +1928,61 @@ def registerAsTable(self, name):
self.registerTempTable(name)
def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this SchemaRDD into the specified table.
+ """Inserts the contents of this DataFrame into the specified table.
Optionally overwriting any existing data.
"""
- self._jschema_rdd.insertInto(tableName, overwrite)
+ self._jdf.insertInto(tableName, overwrite)
def saveAsTable(self, tableName):
- """Creates a new table with the contents of this SchemaRDD."""
- self._jschema_rdd.saveAsTable(tableName)
+ """Creates a new table with the contents of this DataFrame."""
+ self._jdf.saveAsTable(tableName)
def schema(self):
- """Returns the schema of this SchemaRDD (represented by
+ """Returns the schema of this DataFrame (represented by
a L{StructType})."""
- return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
-
- def schemaString(self):
- """Returns the output schema in the tree format."""
- return self._jschema_rdd.schemaString()
+ return _parse_datatype_json_string(self._jdf.schema().json())
def printSchema(self):
"""Prints out the schema in the tree format."""
- print self.schemaString()
+ print (self._jdf.schema().treeString())
def count(self):
"""Return the number of elements in this RDD.
Unlike the base RDD implementation of count, this implementation
- leverages the query optimizer to compute the count on the SchemaRDD,
+ leverages the query optimizer to compute the count on the DataFrame,
which supports features such as filter pushdown.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.count()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.count()
3L
- >>> srdd.count() == srdd.map(lambda x: x).count()
+ >>> df.count() == df.map(lambda x: x).count()
True
"""
- return self._jschema_rdd.count()
+ return self._jdf.count()
def collect(self):
- """Return a list that contains all of the rows in this RDD.
+ """Return a list that contains all of the rows.
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of collect, this implementation
- leverages the query optimizer to perform a collect on the SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()
[Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
- with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
+ with SCCallSiteSync(self._sc) as css:
+ bytesInJava = self._jdf.javaToPython().collect().iterator()
cls = _create_cls(self.schema())
- return map(cls, self._collect_iterator_through_file(bytesInJava))
+ tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+ tempFile.close()
+ self._sc._writeToFile(bytesInJava, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
+ os.unlink(tempFile.name)
+ return [cls(r) for r in rs]
def take(self, num):
"""Take the first num rows of the RDD.
@@ -1974,130 +1990,555 @@ def take(self, num):
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of take, this implementation
- leverages the query optimizer to perform a collect on a SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.take(2)
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.take(2)
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
"""
return self.limit(num).collect()
- # Convert each object in the RDD to a Row with the right class
- # for this SchemaRDD, so that fields can be accessed as attributes.
- def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ def map(self, f):
+ """ Return a new RDD by applying a function to each Row, it's a
+ shorthand for df.rdd.map()
"""
- Return a new RDD by applying a function to each partition of this RDD,
- while tracking the index of the original partition.
+ return self.rdd.map(f)
- >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
- >>> def f(splitIndex, iterator): yield splitIndex
- >>> rdd.mapPartitionsWithIndex(f).sum()
- 6
+ def mapPartitions(self, f, preservesPartitioning=False):
"""
- rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
-
- schema = self.schema()
+ Return a new RDD by applying a function to each partition.
- def applySchema(_, it):
- cls = _create_cls(schema)
- return itertools.imap(cls, it)
-
- objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning)
- return objrdd.mapPartitionsWithIndex(f, preservesPartitioning)
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(iterator): yield 1
+ >>> rdd.mapPartitions(f).sum()
+ 4
+ """
+ return self.rdd.mapPartitions(f, preservesPartitioning)
- # We override the default cache/persist/checkpoint behavior
- # as we want to cache the underlying SchemaRDD object in the JVM,
- # not the PythonRDD checkpointed by the super class
def cache(self):
+ """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+ """
self.is_cached = True
- self._jschema_rdd.cache()
+ self._jdf.cache()
return self
def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+ """ Set the storage level to persist its values across operations
+ after the first time it is computed. This can only be used to assign
+ a new storage level if the RDD does not have a storage level set yet.
+ If no storage level is specified, it defaults to C{MEMORY_ONLY_SER}.
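+
+ A sketch of typical use, assuming `df` is an existing DataFrame::
+
+ df.persist(StorageLevel.MEMORY_AND_DISK)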
+ """
self.is_cached = True
- javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
- self._jschema_rdd.persist(javaStorageLevel)
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdf.persist(javaStorageLevel)
return self
def unpersist(self, blocking=True):
+ """ Mark it as non-persistent, and remove all blocks for it from
+ memory and disk.
+ """
self.is_cached = False
- self._jschema_rdd.unpersist(blocking)
+ self._jdf.unpersist(blocking)
return self
- def checkpoint(self):
- self.is_checkpointed = True
- self._jschema_rdd.checkpoint()
+ # def coalesce(self, numPartitions, shuffle=False):
+ # rdd = self._jdf.coalesce(numPartitions, shuffle, None)
+ # return DataFrame(rdd, self.sql_ctx)
- def isCheckpointed(self):
- return self._jschema_rdd.isCheckpointed()
+ def repartition(self, numPartitions):
+ """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+ partitions.
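+
+ For example, `df.repartition(10)` would return a new DataFrame backed by 10 partitions.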
+ """
+ rdd = self._jdf.repartition(numPartitions, None)
+ return DataFrame(rdd, self.sql_ctx)
- def getCheckpointFile(self):
- checkpointFile = self._jschema_rdd.getCheckpointFile()
- if checkpointFile.isDefined():
- return checkpointFile.get()
+ def sample(self, withReplacement, fraction, seed=None):
+ """
+ Return a sampled subset of this DataFrame.
- def coalesce(self, numPartitions, shuffle=False):
- rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None)
- return SchemaRDD(rdd, self.sql_ctx)
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.sample(False, 0.5, 97).count()
+ 2L
+ """
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ seed = seed if seed is not None else random.randint(0, sys.maxint)
+ rdd = self._jdf.sample(withReplacement, fraction, long(seed))
+ return DataFrame(rdd, self.sql_ctx)
+
+ # def takeSample(self, withReplacement, num, seed=None):
+ # """Return a fixed-size sampled subset of this DataFrame.
+ #
+ # >>> df = sqlCtx.inferSchema(rdd)
+ # >>> df.takeSample(False, 2, 97)
+ # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ # """
+ # seed = seed if seed is not None else random.randint(0, sys.maxint)
+ # with SCCallSiteSync(self.context) as css:
+ # bytesInJava = self._jdf \
+ # .takeSampleToPython(withReplacement, num, long(seed)) \
+ # .iterator()
+ # cls = _create_cls(self.schema())
+ # return map(cls, self._collect_iterator_through_file(bytesInJava))
- def distinct(self, numPartitions=None):
- if numPartitions is None:
- rdd = self._jschema_rdd.distinct()
+ @property
+ def dtypes(self):
+ """Return all column names and their data types as a list.
+ """
+ return [(f.name, str(f.dataType)) for f in self.schema().fields]
+
+ @property
+ def columns(self):
+ """ Return all column names as a list.
+ """
+ return [f.name for f in self.schema().fields]
+
+ def show(self):
+ raise NotImplementedError
+
+ def join(self, other, joinExprs=None, joinType=None):
+ """
+ Join with another DataFrame, using the given join expression.
+ The following performs a full outer join between `df1` and `df2`::
+
+ df1.join(df2, df1.key == df2.key, "outer")
+
+ :param other: Right side of the join
+ :param joinExprs: Join expression
+ :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`,
+ `semijoin`.
+ """
+ if joinType is None:
+ if joinExprs is None:
+ jdf = self._jdf.join(other._jdf)
+ else:
+ jdf = self._jdf.join(other._jdf, joinExprs)
else:
- rdd = self._jschema_rdd.distinct(numPartitions, None)
- return SchemaRDD(rdd, self.sql_ctx)
+ jdf = self._jdf.join(other._jdf, joinExprs, joinType)
+ return DataFrame(jdf, self.sql_ctx)
+
+ def sort(self, *cols):
+ """ Return a new [[DataFrame]] sorted by the specified column,
+ in ascending column.
+
+ :param cols: The columns or expressions used for sorting
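+
+ A usage sketch, assuming `df` has an `age` column::
+
+ df.sort(df.age)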
+ """
+ if not cols:
+ raise ValueError("should sort by at least one column")
+ cols = [Column(c) if isinstance(c, basestring) else c for c in cols]
+ jcols = [c._jc for c in cols]
+ jdf = self._jdf.sort(*jcols)
+ return DataFrame(jdf, self.sql_ctx)
+
+ sortBy = sort
+
+ def head(self, n=None):
+ """ Return the first `n` rows or the first row if n is None. """
+ if n is None:
+ rs = self.head(1)
+ return rs[0] if rs else None
+ return self.take(n)
+
+ def tail(self):
+ raise NotImplementedError
+
+ def __getitem__(self, item):
+ if isinstance(item, basestring):
+ return Column(self._jdf.apply(item))
+
+ # TODO projection
+ raise IndexError
+
+ def __getattr__(self, name):
+ """ Return the column by given name """
+ if isinstance(name, basestring):
+ return Column(self._jdf.apply(name))
+ raise AttributeError
+
+ def As(self, name):
+ """ Alias the current DataFrame """
+ return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
+
+ def select(self, *cols):
+ """ Selecting a set of expressions.::
+
+ df.select()
+ df.select('colA', 'colB')
+ df.select(df.colA, df.colB + 1)
- def intersection(self, other):
- if (other.__class__ is SchemaRDD):
- rdd = self._jschema_rdd.intersection(other._jschema_rdd)
- return SchemaRDD(rdd, self.sql_ctx)
+ """
+ if not cols:
+ cols = ["*"]
+ if isinstance(cols[0], basestring):
+ cols = [_create_column_from_name(n) for n in cols]
else:
- raise ValueError("Can only intersect with another SchemaRDD")
+ cols = [c._jc for c in cols]
+ jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+ return DataFrame(jdf, self.sql_ctx)
- def repartition(self, numPartitions):
- rdd = self._jschema_rdd.repartition(numPartitions, None)
- return SchemaRDD(rdd, self.sql_ctx)
+ def filter(self, condition):
+ """ Filtering rows using the given condition::
- def subtract(self, other, numPartitions=None):
- if (other.__class__ is SchemaRDD):
- if numPartitions is None:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd)
- else:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd,
- numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ df.filter(df.age > 15)
+ df.where(df.age > 15)
+
+ """
+ return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+
+ where = filter
+
+ def groupBy(self, *cols):
+ """ Group the [[DataFrame]] using the specified columns,
+ so we can run aggregation on them. See :class:`GroupedDataFrame`
+ for all the available aggregate functions::
+
+ df.groupBy(df.department).avg()
+ df.groupBy("department", "gender").agg({
+ "salary": "avg",
+ "age": "max",
+ })
+ """
+ if cols and isinstance(cols[0], basestring):
+ cols = [_create_column_from_name(n) for n in cols]
else:
- raise ValueError("Can only subtract another SchemaRDD")
+ cols = [c._jc for c in cols]
+ jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+ return GroupedDataFrame(jdf, self.sql_ctx)
- def sample(self, withReplacement, fraction, seed=None):
+ def agg(self, *exprs):
+ """ Aggregate on the entire [[DataFrame]] without groups
+ (shorthand for df.groupBy().agg())::
+
+ df.agg({"age": "max", "salary": "avg"})
"""
- Return a sampled subset of this SchemaRDD.
+ return self.groupBy().agg(*exprs)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.sample(False, 0.5, 97).count()
- 2L
+ def unionAll(self, other):
+ """ Return a new DataFrame containing union of rows in this
+ frame and another frame.
+
+ This is equivalent to `UNION ALL` in SQL.
"""
- assert fraction >= 0.0, "Negative fraction value: %s" % fraction
- seed = seed if seed is not None else random.randint(0, sys.maxint)
- rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
- return SchemaRDD(rdd, self.sql_ctx)
+ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
- def takeSample(self, withReplacement, num, seed=None):
- """Return a fixed-size sampled subset of this SchemaRDD.
+ def intersect(self, other):
+ """ Return a new [[DataFrame]] containing rows only in
+ both this frame and another frame.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.takeSample(False, 2, 97)
- [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ This is equivalent to `INTERSECT` in SQL.
"""
- seed = seed if seed is not None else random.randint(0, sys.maxint)
- with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jschema_rdd.baseSchemaRDD() \
- .takeSampleToPython(withReplacement, num, long(seed)) \
- .iterator()
- cls = _create_cls(self.schema())
- return map(cls, self._collect_iterator_through_file(bytesInJava))
+ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+
+ def Except(self, other):
+ """ Return a new [[DataFrame]] containing rows in this frame
+ but not in another frame.
+
+ This is equivalent to `EXCEPT` in SQL.
+ """
+ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+
+ def sample(self, withReplacement, fraction, seed=None):
+ """ Return a new DataFrame by sampling a fraction of rows. """
+ if seed is None:
+ jdf = self._jdf.sample(withReplacement, fraction)
+ else:
+ jdf = self._jdf.sample(withReplacement, fraction, seed)
+ return DataFrame(jdf, self.sql_ctx)
+
+ def addColumn(self, colName, col):
+ """ Return a new [[DataFrame]] by adding a column. """
+ return self.select('*', col.As(colName))
+
+ def removeColumn(self, colName):
+ raise NotImplementedError
+
+
+# Having SchemaRDD for backward compatibility (for docs)
+class SchemaRDD(DataFrame):
+ """
+ SchemaRDD is deprecated, please use DataFrame
+ """
+
+
+def dfapi(f):
+ def _api(self):
+ name = f.__name__
+ jdf = getattr(self._jdf, name)()
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
+class GroupedDataFrame(object):
+
+ """
+ A set of methods for aggregations on a :class:`DataFrame`,
+ created by DataFrame.groupBy().
+ """
+
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
+ self.sql_ctx = sql_ctx
+
+ def agg(self, *exprs):
+ """ Compute aggregates by specifying a map from column name
+ to aggregate methods.
+
+ The available aggregate methods are `avg`, `max`, `min`,
+ `sum`, `count`.
+
+ :param exprs: a list of aggregate columns or a map from column
+ name to aggregate methods.
+ """
+ if len(exprs) == 1 and isinstance(exprs[0], dict):
+ jmap = MapConverter().convert(exprs[0],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(jmap)
+ else:
+ # Columns
+ assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
+ jdf = self._jdf.agg(*exprs)
+ return DataFrame(jdf, self.sql_ctx)
+
+ @dfapi
+ def count(self):
+ """ Count the number of rows for each group. """
+
+ @dfapi
+ def mean(self):
+ """Compute the average value for each numeric columns
+ for each group. This is an alias for `avg`."""
+
+ @dfapi
+ def avg(self):
+ """Compute the average value for each numeric columns
+ for each group."""
+
+ @dfapi
+ def max(self):
+ """Compute the max value for each numeric columns for
+ each group. """
+
+ @dfapi
+ def min(self):
+ """Compute the min value for each numeric column for
+ each group."""
+
+ @dfapi
+ def sum(self):
+ """Compute the sum for each numeric columns for each
+ group."""
+
+
+SCALA_METHOD_MAPPINGS = {
+ '=': '$eq',
+ '>': '$greater',
+ '<': '$less',
+ '+': '$plus',
+ '-': '$minus',
+ '*': '$times',
+ '/': '$div',
+ '!': '$bang',
+ '@': '$at',
+ '#': '$hash',
+ '%': '$percent',
+ '^': '$up',
+ '&': '$amp',
+ '~': '$tilde',
+ '?': '$qmark',
+ '|': '$bar',
+ '\\': '$bslash',
+ ':': '$colon',
+}
+
+
+def _create_column_from_literal(literal):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Literal.apply(literal)
+
+
+def _create_column_from_name(name):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Column(name)
+
+
+def _scalaMethod(name):
+ """ Translate operators into methodName in Scala
+
+ For example:
+ >>> _scalaMethod('+')
+ '$plus'
+ >>> _scalaMethod('>=')
+ '$greater$eq'
+ >>> _scalaMethod('cast')
+ 'cast'
+ """
+ return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
+
+
+def _unary_op(name):
+ """ Create a method for given unary operator """
+ def _(self):
+ return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx)
+ return _
+
+
+def _bin_op(name):
+ """ Create a method for given binary operator """
+ def _(self, other):
+ if isinstance(other, Column):
+ jc = other._jc
+ else:
+ jc = _create_column_from_literal(other)
+ return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
+ return _
+
+
+def _reverse_op(name):
+ """ Create a method for binary operator (this object is on right side)
+ """
+ def _(self, other):
+ return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc),
+ self._jdf, self.sql_ctx)
+ return _
+
+
+class Column(DataFrame):
+
+ """
+ A column in a DataFrame.
+
+ `Column` instances can be created by::
+
+ # 1. Select a column out of a DataFrame
+ df.colName
+ df["colName"]
+
+ # 2. Create from an expression
+ df["colName"] + 1
+ """
+
+ def __init__(self, jc, jdf=None, sql_ctx=None):
+ self._jc = jc
+ super(Column, self).__init__(jdf, sql_ctx)
+
+ # arithmetic operators
+ __neg__ = _unary_op("unary_-")
+ __add__ = _bin_op("+")
+ __sub__ = _bin_op("-")
+ __mul__ = _bin_op("*")
+ __div__ = _bin_op("/")
+ __mod__ = _bin_op("%")
+ __radd__ = _bin_op("+")
+ __rsub__ = _reverse_op("-")
+ __rmul__ = _bin_op("*")
+ __rdiv__ = _reverse_op("/")
+ __rmod__ = _reverse_op("%")
+ __abs__ = _unary_op("abs")
+ abs = _unary_op("abs")
+ sqrt = _unary_op("sqrt")
+
+ # comparison and logical operators
+ __eq__ = _bin_op("===")
+ __ne__ = _bin_op("!==")
+ __lt__ = _bin_op("<")
+ __le__ = _bin_op("<=")
+ __ge__ = _bin_op(">=")
+ __gt__ = _bin_op(">")
+ # `and`, `or`, `not` cannot be overloaded in Python
+ And = _bin_op('&&')
+ Or = _bin_op('||')
+ Not = _unary_op('unary_!')
+
+ # bitwise operators
+ __and__ = _bin_op("&")
+ __or__ = _bin_op("|")
+ __invert__ = _unary_op("unary_~")
+ __xor__ = _bin_op("^")
+ # __lshift__ = _bin_op("<<")
+ # __rshift__ = _bin_op(">>")
+ __rand__ = _bin_op("&")
+ __ror__ = _bin_op("|")
+ __rxor__ = _bin_op("^")
+ # __rlshift__ = _reverse_op("<<")
+ # __rrshift__ = _reverse_op(">>")
+
+ # container operators
+ __contains__ = _bin_op("contains")
+ __getitem__ = _bin_op("getItem")
+ # __getattr__ = _bin_op("getField")
+
+ # string methods
+ rlike = _bin_op("rlike")
+ like = _bin_op("like")
+ startswith = _bin_op("startsWith")
+ endswith = _bin_op("endsWith")
+ upper = _unary_op("upper")
+ lower = _unary_op("lower")
+
+ def substr(self, startPos, pos):
+ if type(startPos) != type(pos):
+ raise TypeError("Can not mix the type")
+ if isinstance(startPos, (int, long)):
+
+ jc = self._jc.substr(startPos, pos)
+ elif isinstance(startPos, Column):
+ jc = self._jc.substr(startPos._jc, pos._jc)
+ else:
+ raise TypeError("Unexpected type: %s" % type(startPos))
+ return Column(jc, self._jdf, self.sql_ctx)
+
+ __getslice__ = substr
+
+ # order
+ asc = _unary_op("asc")
+ desc = _unary_op("desc")
+
+ isNull = _unary_op("isNull")
+ isNotNull = _unary_op("isNotNull")
+
+ # `as` is keyword
+ def As(self, alias):
+ return Column(getattr(self._jc, "as")(alias), self._jdf, self.sql_ctx)
+
+ def cast(self, dataType):
+ if self.sql_ctx is None:
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ else:
+ ssql_ctx = self.sql_ctx._ssql_ctx
+ jdt = ssql_ctx.parseDataType(dataType.json())
+ return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+
+
+def _aggregate_func(name):
+ """ Creat a function for aggregator by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ # FIXME: can not access dsl.min/max ...
+ jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol)
+ return Column(jc)
+ return staticmethod(_)
+
+
+class Aggregator(object):
+ """
+ A collection of builtin aggregators
+ """
+ max = _aggregate_func("max")
+ min = _aggregate_func("min")
+ avg = mean = _aggregate_func("mean")
+ sum = _aggregate_func("sum")
+ first = _aggregate_func("first")
+ last = _aggregate_func("last")
+ count = _aggregate_func("count")
def _test():
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b474fcf5bfb7e..e8e207af462de 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -806,6 +806,9 @@ def tearDownClass(cls):
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
+ self.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = self.sc.parallelize(self.testData)
+ self.df = self.sqlCtx.inferSchema(rdd)
def test_udf(self):
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
@@ -821,7 +824,7 @@ def test_udf2(self):
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.inferSchema(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -839,68 +842,51 @@ def test_broadcast_in_udf(self):
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- srdd = self.sqlCtx.jsonRDD(rdd)
- srdd.count()
- srdd.collect()
- srdd.schemaString()
- srdd.schema()
+ df = self.sqlCtx.jsonRDD(rdd)
+ df.count()
+ df.collect()
+ df.schema()
# cache and checkpoint
- self.assertFalse(srdd.is_cached)
- srdd.persist()
- srdd.unpersist()
- srdd.cache()
- self.assertTrue(srdd.is_cached)
- self.assertFalse(srdd.isCheckpointed())
- self.assertEqual(None, srdd.getCheckpointFile())
-
- srdd = srdd.coalesce(2, True)
- srdd = srdd.repartition(3)
- srdd = srdd.distinct()
- srdd.intersection(srdd)
- self.assertEqual(2, srdd.count())
-
- srdd.registerTempTable("temp")
- srdd = self.sqlCtx.sql("select foo from temp")
- srdd.count()
- srdd.collect()
-
- def test_distinct(self):
- rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10)
- srdd = self.sqlCtx.jsonRDD(rdd)
- self.assertEquals(srdd.getNumPartitions(), 10)
- self.assertEquals(srdd.distinct().count(), 3)
- result = srdd.distinct(5)
- self.assertEquals(result.getNumPartitions(), 5)
- self.assertEquals(result.count(), 3)
+ self.assertFalse(df.is_cached)
+ df.persist()
+ df.unpersist()
+ df.cache()
+ self.assertTrue(df.is_cached)
+ self.assertEqual(2, df.count())
+
+ df.registerTempTable("temp")
+ df = self.sqlCtx.sql("select foo from temp")
+ df.count()
+ df.collect()
def test_apply_schema_to_row(self):
- srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
- self.assertEqual(srdd.collect(), srdd2.collect())
+ df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+ df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
- self.assertEqual(10, srdd3.count())
+ df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- row = srdd.first()
+ df = self.sqlCtx.inferSchema(rdd)
+ row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
self.assertEqual("2", row.d["key"].d)
- l = srdd.map(lambda x: x.l).first()
+ l = df.map(lambda x: x.l).first()
self.assertEqual(1, len(l))
self.assertEqual('s', l[0].b)
- d = srdd.map(lambda x: x.d).first()
+ d = df.map(lambda x: x.d).first()
self.assertEqual(1, len(d))
self.assertEqual(1.0, d["key"].c)
- row = srdd.map(lambda x: x.d["key"]).first()
+ row = df.map(lambda x: x.d["key"]).first()
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)
@@ -908,26 +894,26 @@ def test_infer_schema(self):
d = [Row(l=[], d={}),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- self.assertEqual([], srdd.map(lambda r: r.l).first())
- self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
- srdd.registerTempTable("test")
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEqual([], df.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+ df.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
- self.assertEqual(1, result.first()[0])
+ self.assertEqual(1, result.head()[0])
- srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
- self.assertEqual(srdd.schema(), srdd2.schema())
- self.assertEqual({}, srdd2.map(lambda r: r.d).first())
- self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
- srdd2.registerTempTable("test2")
+ df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual({}, df2.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+ df2.registerTempTable("test2")
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
- self.assertEqual(1, result.first()[0])
+ self.assertEqual(1, result.head()[0])
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- k, v = srdd.first().m.items()[0]
+ df = self.sqlCtx.inferSchema(rdd)
+ k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -935,9 +921,9 @@ def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
rdd = self.sc.parallelize([row])
- srdd = self.sqlCtx.inferSchema(rdd)
- srdd.registerTempTable("test")
- row = self.sqlCtx.sql("select l, d from test").first()
+ df = self.sqlCtx.inferSchema(rdd)
+ df.registerTempTable("test")
+ row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
self.assertEqual(1.0, row.asDict()['d']['key'].c)
@@ -945,12 +931,12 @@ def test_infer_schema_with_udt(self):
from pyspark.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- srdd = self.sqlCtx.inferSchema(rdd)
- schema = srdd.schema()
+ df = self.sqlCtx.inferSchema(rdd)
+ schema = df.schema()
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
- srdd.registerTempTable("labeled_point")
- point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+ df.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
def test_apply_schema_with_udt(self):
@@ -959,21 +945,52 @@ def test_apply_schema_with_udt(self):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- srdd = self.sqlCtx.applySchema(rdd, schema)
- point = srdd.first().point
+ df = self.sqlCtx.applySchema(rdd, schema)
+ point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- srdd0 = self.sqlCtx.inferSchema(rdd)
+ df0 = self.sqlCtx.inferSchema(rdd)
output_dir = os.path.join(self.tempdir.name, "labeled_point")
- srdd0.saveAsParquetFile(output_dir)
- srdd1 = self.sqlCtx.parquetFile(output_dir)
- point = srdd1.first().point
+ df0.saveAsParquetFile(output_dir)
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ point = df1.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ def test_column_operators(self):
+ from pyspark.sql import Column, LongType
+ ci = self.df.key
+ cs = self.df.value
+ c = ci == cs
+ self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+ rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+ self.assertTrue(all(isinstance(c, Column) for c in rcc))
+ cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+ self.assertTrue(all(isinstance(c, Column) for c in cb))
+ cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci)
+ self.assertTrue(all(isinstance(c, Column) for c in cbit))
+ css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+ self.assertTrue(all(isinstance(c, Column) for c in css))
+ self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+ def test_column_select(self):
+ df = self.df
+ self.assertEqual(self.testData, df.select("*").collect())
+ self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+ self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+ def test_aggregator(self):
+ df = self.df
+ g = df.groupBy()
+ self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+ self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+ # TODO(davies): fix aggregators
+ from pyspark.sql import Aggregator as Agg
+ # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+
class InputFormatTests(ReusedPySparkTestCase):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
index 22941edef2d46..4c5fb3f45bf49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
@@ -47,7 +47,7 @@ object NewRelationInstances extends Rule[LogicalPlan] {
.toSet
plan transform {
- case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance
+ case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance()
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 3035d934ff9f8..f388cd5972bac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -77,6 +77,9 @@ abstract class Attribute extends NamedExpression {
* For example the SQL expression "1 + 1 AS a" could be represented as follows:
* Alias(Add(Literal(1), Literal(1), "a")()
*
+ * Note that exprId and qualifiers are in a separate parameter list because
+ * we only pattern match on child and name.
+ *
* @param child the computation being performed
* @param name the name to be associated with the result of computing [[child]].
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
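A quick illustration of the note above, since the two-parameter-list design is easy to miss: a pattern match on Alias binds only the child and the name, never the exprId or qualifiers. This is only a sketch against the catalyst expression API; the literal values are arbitrary.

  import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Literal}

  // "1 + 1 AS a"; the empty second argument list generates a fresh exprId.
  val aliased = Alias(Add(Literal(1), Literal(1)), "a")()

  aliased match {
    case Alias(child, name) => println(s"$name := $child")  // exprId and qualifiers are not bound
  }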
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 613f4bb09daf5..5dc0539caec24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -17,9 +17,24 @@
package org.apache.spark.sql.catalyst.plans
+object JoinType {
+ def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
+ case "inner" => Inner
+ case "outer" | "full" | "fullouter" => FullOuter
+ case "leftouter" | "left" => LeftOuter
+ case "rightouter" | "right" => RightOuter
+ case "leftsemi" => LeftSemi
+ }
+}
+
sealed abstract class JoinType
+
case object Inner extends JoinType
+
case object LeftOuter extends JoinType
+
case object RightOuter extends JoinType
+
case object FullOuter extends JoinType
+
case object LeftSemi extends JoinType
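For reference, JoinType.apply lowercases the string and strips underscores before matching, so the accepted spellings are fairly forgiving. A small sketch using only what this file defines:

  import org.apache.spark.sql.catalyst.plans._

  assert(JoinType("inner") == Inner)
  assert(JoinType("LEFT_OUTER") == LeftOuter)  // case and underscores are ignored
  assert(JoinType("full") == FullOuter)        // "outer", "full" and "fullouter" all map here
  assert(JoinType("leftsemi") == LeftSemi)
  // Any other string falls through the match and throws a scala.MatchError.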
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
index 19769986ef58c..d90af45b375e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
@@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.types.{StructType, StructField}
object LocalRelation {
- def apply(output: Attribute*) =
- new LocalRelation(output)
+ def apply(output: Attribute*): LocalRelation = new LocalRelation(output)
+
+ def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation(
+ StructType(output1 +: output).toAttributes
+ )
}
case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
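The new overload lets tests declare a relation from StructFields directly instead of pre-built attributes. A minimal sketch using only the types imported above; the field names are arbitrary:

  import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField}

  // Equivalent to building StructType(fields).toAttributes by hand.
  val relation = LocalRelation(
    StructField("id", IntegerType),
    StructField("name", StringType))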
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index e715d9434a2ab..bc22f688338b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -80,7 +80,7 @@ private[sql] trait CacheManager {
* the in-memory columnar representation of the underlying table is expensive.
*/
private[sql] def cacheQuery(
- query: SchemaRDD,
+ query: DataFrame,
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
@@ -100,7 +100,7 @@ private[sql] trait CacheManager {
}
/** Removes the data for the given SchemaRDD from the cache */
- private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock {
+ private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
require(dataIndex >= 0, s"Table $query is not cached.")
@@ -110,7 +110,7 @@ private[sql] trait CacheManager {
/** Tries to remove the data for the given SchemaRDD from the cache if it's cached */
private[sql] def tryUncacheQuery(
- query: SchemaRDD,
+ query: DataFrame,
blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -123,7 +123,7 @@ private[sql] trait CacheManager {
}
/** Optionally returns cached data for the given SchemaRDD */
- private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
+ private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
}
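These entry points now take a DataFrame rather than a SchemaRDD. They stay private[sql], so the sketch below is only meaningful inside the org.apache.spark.sql package (DataFrame.persist() later in this patch is the public path that delegates to cacheQuery); the parquet path is a placeholder.

  import org.apache.spark.storage.StorageLevel

  val df = sqlContext.parquetFile("...")  // any DataFrame
  sqlContext.cacheQuery(df, Some("people"), StorageLevel.MEMORY_AND_DISK)
  // ... run queries against the cached plan ...
  sqlContext.tryUncacheQuery(df)          // returns false if the plan was never cached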
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
new file mode 100644
index 0000000000000..7fc8347428df4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -0,0 +1,528 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
+import org.apache.spark.sql.types._
+
+
+object Column {
+ def unapply(col: Column): Option[Expression] = Some(col.expr)
+
+ def apply(colName: String): Column = new Column(colName)
+}
+
+
+/**
+ * A column in a [[DataFrame]].
+ *
+ * `Column` instances can be created by:
+ * {{{
+ * // 1. Select a column out of a DataFrame
+ * df("colName")
+ *
+ * // 2. Create a literal expression
+ * Literal(1)
+ *
+ * // 3. Create new columns from existing ones, e.g. df("colName") + 1
+ * }}}
+ *
+ */
+// TODO: Improve documentation.
+class Column(
+ sqlContext: Option[SQLContext],
+ plan: Option[LogicalPlan],
+ val expr: Expression)
+ extends DataFrame(sqlContext, plan) with ExpressionApi {
+
+ /** Turn a Catalyst expression into a `Column`. */
+ protected[sql] def this(expr: Expression) = this(None, None, expr)
+
+ /**
+ * Create a new `Column` expression based on a column or attribute name.
+ * The resolution of this is the same as SQL. For example:
+ *
+ * - "colName" becomes an expression selecting the column named "colName".
+ * - "*" becomes an expression selecting all columns.
+ * - "df.*" becomes an expression selecting all columns in data frame "df".
+ */
+ def this(name: String) = this(name match {
+ case "*" => Star(None)
+ case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2)))
+ case _ => UnresolvedAttribute(name)
+ })
+
+ override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined
+
+ /**
+ * An implicit conversion function internal to this class. This function creates a new Column
+ * based on an expression. If the expression itself is not named, it aliases the expression
+ * by calling it "col".
+ */
+ private[this] implicit def toColumn(expr: Expression): Column = {
+ val projectedPlan = plan.map { p =>
+ Project(Seq(expr match {
+ case named: NamedExpression => named
+ case unnamed: Expression => Alias(unnamed, "col")()
+ }), p)
+ }
+ new Column(sqlContext, projectedPlan, expr)
+ }
+
+ /**
+ * Unary minus, i.e. negate the expression.
+ * {{{
+ * // Select the amount column and negate all values.
+ * df.select( -df("amount") )
+ * }}}
+ */
+ override def unary_- : Column = UnaryMinus(expr)
+
+ /**
+ * Bitwise NOT.
+ * {{{
+ * // Select the flags column and negate every bit.
+ * df.select( ~df("flags") )
+ * }}}
+ */
+ override def unary_~ : Column = BitwiseNot(expr)
+
+ /**
+ * Invert a boolean expression, i.e. NOT.
+ * {{{
+ * // Select rows that are not active (isActive === false)
+ * df.select( !df("isActive") )
+ * }}}
+ */
+ override def unary_! : Column = Not(expr)
+
+
+ /**
+ * Equality test with an expression.
+ * {{{
+ * // The following two both select rows in which colA equals colB.
+ * df.select( df("colA") === df("colB") )
+ * df.select( df("colA").equalTo(df("colB")) )
+ * }}}
+ */
+ override def === (other: Column): Column = EqualTo(expr, other.expr)
+
+ /**
+ * Equality test with a literal value.
+ * {{{
+ * // The following two both select rows in which colA is "Zaharia".
+ * df.select( df("colA") === "Zaharia")
+ * df.select( df("colA").equalTo("Zaharia") )
+ * }}}
+ */
+ override def === (literal: Any): Column = this === Literal.anyToLiteral(literal)
+
+ /**
+ * Equality test with an expression.
+ * {{{
+ * // The following two both select rows in which colA equals colB.
+ * df.select( df("colA") === df("colB") )
+ * df.select( df("colA").equalTo(df("colB")) )
+ * }}}
+ */
+ override def equalTo(other: Column): Column = this === other
+
+ /**
+ * Equality test with a literal value.
+ * {{{
+ * // The following two both select rows in which colA is "Zaharia".
+ * df.select( df("colA") === "Zaharia")
+ * df.select( df("colA").equalTo("Zaharia") )
+ * }}}
+ */
+ override def equalTo(literal: Any): Column = this === literal
+
+ /**
+ * Inequality test with an expression.
+ * {{{
+ * // The following two both select rows in which colA does not equal colB.
+ * df.select( df("colA") !== df("colB") )
+ * df.select( !(df("colA") === df("colB")) )
+ * }}}
+ */
+ override def !== (other: Column): Column = Not(EqualTo(expr, other.expr))
+
+ /**
+ * Inequality test with a literal value.
+ * {{{
+ * // The following two both select rows in which colA does not equal 15.
+ * df.select( df("colA") !== 15 )
+ * df.select( !(df("colA") === 15) )
+ * }}}
+ */
+ override def !== (literal: Any): Column = this !== Literal.anyToLiteral(literal)
+
+ /**
+ * Greater than an expression.
+ * {{{
+ * // The following selects people older than 21.
+ * people.select( people("age") > Literal(21) )
+ * }}}
+ */
+ override def > (other: Column): Column = GreaterThan(expr, other.expr)
+
+ /**
+ * Greater than a literal value.
+ * {{{
+ * // The following selects people older than 21.
+ * people.select( people("age") > 21 )
+ * }}}
+ */
+ override def > (literal: Any): Column = this > Literal.anyToLiteral(literal)
+
+ /**
+ * Less than an expression.
+ * {{{
+ * // The following selects people younger than 21.
+ * people.select( people("age") < Literal(21) )
+ * }}}
+ */
+ override def < (other: Column): Column = LessThan(expr, other.expr)
+
+ /**
+ * Less than a literal value.
+ * {{{
+ * // The following selects people younger than 21.
+ * people.select( people("age") < 21 )
+ * }}}
+ */
+ override def < (literal: Any): Column = this < Literal.anyToLiteral(literal)
+
+ /**
+ * Less than or equal to an expression.
+ * {{{
+ * // The following selects people age 21 or younger.
+ * people.select( people("age") <= Literal(21) )
+ * }}}
+ */
+ override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr)
+
+ /**
+ * Less than or equal to a literal value.
+ * {{{
+ * // The following selects people age 21 or younger.
+ * people.select( people("age") <= 21 )
+ * }}}
+ */
+ override def <= (literal: Any): Column = this <= Literal.anyToLiteral(literal)
+
+ /**
+ * Greater than or equal to an expression.
+ * {{{
+ * // The following selects people age 21 or older.
+ * people.select( people("age") >= Literal(21) )
+ * }}}
+ */
+ override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr)
+
+ /**
+ * Greater than or equal to a literal value.
+ * {{{
+ * // The following selects people age 21 or older.
+ * people.select( people("age") >= 21 )
+ * }}}
+ */
+ override def >= (literal: Any): Column = this >= Literal.anyToLiteral(literal)
+
+ /**
+ * Equality test with an expression that is safe for null values.
+ */
+ override def <=> (other: Column): Column = EqualNullSafe(expr, other.expr)
+
+ /**
+ * Equality test with a literal value that is safe for null values.
+ */
+ override def <=> (literal: Any): Column = this <=> Literal.anyToLiteral(literal)
+
+ /**
+ * True if the current expression is null.
+ */
+ override def isNull: Column = IsNull(expr)
+
+ /**
+ * True if the current expression is NOT null.
+ */
+ override def isNotNull: Column = IsNotNull(expr)
+
+ /**
+ * Boolean OR with an expression.
+ * {{{
+ * // The following selects people that are in school or employed.
+ * people.select( people("inSchool") || people("isEmployed") )
+ * }}}
+ */
+ override def || (other: Column): Column = Or(expr, other.expr)
+
+ /**
+ * Boolean OR with a literal value.
+ * {{{
+ * // The following selects everything.
+ * people.select( people("inSchool") || true )
+ * }}}
+ */
+ override def || (literal: Boolean): Column = this || Literal.anyToLiteral(literal)
+
+ /**
+ * Boolean AND with an expression.
+ * {{{
+ * // The following selects people that are in school and employed at the same time.
+ * people.select( people("inSchool") && people("isEmployed") )
+ * }}}
+ */
+ override def && (other: Column): Column = And(expr, other.expr)
+
+ /**
+ * Boolean AND with a literal value.
+ * {{{
+ * // The following selects people that are in school.
+ * people.select( people("inSchool") && true )
+ * }}}
+ */
+ override def && (literal: Boolean): Column = this && Literal.anyToLiteral(literal)
+
+ /**
+ * Bitwise AND with an expression.
+ */
+ override def & (other: Column): Column = BitwiseAnd(expr, other.expr)
+
+ /**
+ * Bitwise AND with a literal value.
+ */
+ override def & (literal: Any): Column = this & Literal.anyToLiteral(literal)
+
+ /**
+ * Bitwise OR with an expression.
+ */
+ override def | (other: Column): Column = BitwiseOr(expr, other.expr)
+
+ /**
+ * Bitwise OR with a literal value.
+ */
+ override def | (literal: Any): Column = this | Literal.anyToLiteral(literal)
+
+ /**
+ * Bitwise XOR with an expression.
+ */
+ override def ^ (other: Column): Column = BitwiseXor(expr, other.expr)
+
+ /**
+ * Bitwise XOR with a literal value.
+ */
+ override def ^ (literal: Any): Column = this ^ Literal.anyToLiteral(literal)
+
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // The following selects the sum of a person's height and weight.
+ * people.select( people("height") + people("weight") )
+ * }}}
+ */
+ override def + (other: Column): Column = Add(expr, other.expr)
+
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // The following selects the sum of a person's height and 10.
+ * people.select( people("height") + 10 )
+ * }}}
+ */
+ override def + (literal: Any): Column = this + Literal.anyToLiteral(literal)
+
+ /**
+ * Subtraction. Subtract the other expression from this expression.
+ * {{{
+ * // The following selects the difference between people's height and their weight.
+ * people.select( people("height") - people("weight") )
+ * }}}
+ */
+ override def - (other: Column): Column = Subtract(expr, other.expr)
+
+ /**
+ * Subtraction. Subtract a literal value from this expression.
+ * {{{
+ * // The following subtracts 10 from a person's height.
+ * people.select( people("height") - 10 )
+ * }}}
+ */
+ override def - (literal: Any): Column = this - Literal.anyToLiteral(literal)
+
+ /**
+ * Multiply this expression and another expression.
+ * {{{
+ * // The following multiplies a person's height by their weight.
+ * people.select( people("height") * people("weight") )
+ * }}}
+ */
+ override def * (other: Column): Column = Multiply(expr, other.expr)
+
+ /**
+ * Multiply this expression and a literal value.
+ * {{{
+ * // The following multiplies a person's height by 10.
+ * people.select( people("height") * 10 )
+ * }}}
+ */
+ override def * (literal: Any): Column = this * Literal.anyToLiteral(literal)
+
+ /**
+ * Divide this expression by another expression.
+ * {{{
+ * // The following divides a person's height by their weight.
+ * people.select( people("height") / people("weight") )
+ * }}}
+ */
+ override def / (other: Column): Column = Divide(expr, other.expr)
+
+ /**
+ * Divide this expression by a literal value.
+ * {{{
+ * // The following divides a person's height by 10.
+ * people.select( people("height") / 10 )
+ * }}}
+ */
+ override def / (literal: Any): Column = this / Literal.anyToLiteral(literal)
+
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ */
+ override def % (other: Column): Column = Remainder(expr, other.expr)
+
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ */
+ override def % (literal: Any): Column = this % Literal.anyToLiteral(literal)
+
+
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the evaluated values of the arguments.
+ */
+ @scala.annotation.varargs
+ override def in(list: Column*): Column = In(expr, list.map(_.expr))
+
+ override def like(other: Column): Column = Like(expr, other.expr)
+
+ override def like(literal: String): Column = this.like(Literal.anyToLiteral(literal))
+
+ override def rlike(other: Column): Column = RLike(expr, other.expr)
+
+ override def rlike(literal: String): Column = this.rlike(Literal.anyToLiteral(literal))
+
+
+ override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+
+ override def getItem(ordinal: Column): Column = GetItem(expr, ordinal.expr)
+
+ override def getField(fieldName: String): Column = GetField(expr, fieldName)
+
+
+ override def substr(startPos: Column, len: Column): Column =
+ Substring(expr, startPos.expr, len.expr)
+
+ override def substr(startPos: Int, len: Int): Column =
+ this.substr(Literal.anyToLiteral(startPos), Literal.anyToLiteral(len))
+
+ override def contains(other: Column): Column = Contains(expr, other.expr)
+
+ override def contains(literal: Any): Column = this.contains(Literal.anyToLiteral(literal))
+
+
+ override def startsWith(other: Column): Column = StartsWith(expr, other.expr)
+
+ override def startsWith(literal: String): Column = this.startsWith(Literal.anyToLiteral(literal))
+
+ override def endsWith(other: Column): Column = EndsWith(expr, other.expr)
+
+ override def endsWith(literal: String): Column = this.endsWith(Literal.anyToLiteral(literal))
+
+ override def as(alias: String): Column = Alias(expr, alias)()
+
+ override def cast(to: DataType): Column = Cast(expr, to)
+
+ override def desc: Column = SortOrder(expr, Descending)
+
+ override def asc: Column = SortOrder(expr, Ascending)
+}
+
+
+class ColumnName(name: String) extends Column(name) {
+
+ /** Creates a new StructField of type boolean */
+ def boolean: StructField = StructField(name, BooleanType)
+
+ /** Creates a new StructField of type byte */
+ def byte: StructField = StructField(name, ByteType)
+
+ /** Creates a new StructField of type short */
+ def short: StructField = StructField(name, ShortType)
+
+ /** Creates a new StructField of type int */
+ def int: StructField = StructField(name, IntegerType)
+
+ /** Creates a new StructField of type long */
+ def long: StructField = StructField(name, LongType)
+
+ /** Creates a new StructField of type float */
+ def float: StructField = StructField(name, FloatType)
+
+ /** Creates a new StructField of type double */
+ def double: StructField = StructField(name, DoubleType)
+
+ /** Creates a new StructField of type string */
+ def string: StructField = StructField(name, StringType)
+
+ /** Creates a new StructField of type date */
+ def date: StructField = StructField(name, DateType)
+
+ /** Creates a new StructField of type decimal */
+ def decimal: StructField = StructField(name, DecimalType.Unlimited)
+
+ /** Creates a new StructField of type decimal */
+ def decimal(precision: Int, scale: Int): StructField =
+ StructField(name, DecimalType(precision, scale))
+
+ /** Creates a new StructField of type timestamp */
+ def timestamp: StructField = StructField(name, TimestampType)
+
+ /** Creates a new StructField of type binary */
+ def binary: StructField = StructField(name, BinaryType)
+
+ /** Creates a new StructField of type array */
+ def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType))
+
+ /** Creates a new StructField of type map */
+ def map(keyType: DataType, valueType: DataType): StructField =
+ map(MapType(keyType, valueType))
+
+ def map(mapType: MapType): StructField = StructField(name, mapType)
+
+ /** Creates a new StructField of type struct */
+ def struct(fields: StructField*): StructField = struct(StructType(fields))
+
+ def struct(structType: StructType): StructField = StructField(name, structType)
+}
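A short usage sketch of the operators above; the people DataFrame, its columns, and the parquet path are placeholders rather than anything defined in this patch:

  import org.apache.spark.sql.types.DoubleType

  val people = sqlContext.parquetFile("...")            // any DataFrame
  val adult  = people("age") >= 21                      // the literal goes through anyToLiteral
  val label  = people("name").as("fullName")            // an Alias wrapped in a Column

  people.select(label, (people("height") * 2).cast(DoubleType))
    .where(adult && people("name").startsWith("A"))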
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
new file mode 100644
index 0000000000000..d0bb3640f8c1c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -0,0 +1,596 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
+import java.util.{ArrayList, List => JList}
+
+import com.fasterxml.jackson.core.JsonFactory
+import net.razorvine.pickle.Pickler
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
+import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.util.Utils
+
+
+/**
+ * A collection of rows that have the same columns.
+ *
+ * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using
+ * various functions in [[SQLContext]].
+ * {{{
+ * val people = sqlContext.parquetFile("...")
+ * }}}
+ *
+ * Once created, it can be manipulated using the various domain-specific-language (DSL) functions
+ * defined in: [[DataFrame]] (this class), [[Column]], and [[dsl]] for Scala DSL.
+ *
+ * To select a column from the data frame, use the apply method:
+ * {{{
+ * val ageCol = people("age") // in Scala
+ * Column ageCol = people.apply("age") // in Java
+ * }}}
+ *
+ * Note that the [[Column]] type can also be manipulated through its various functions.
+ * {{{
+ * // The following creates a new column that increases everybody's age by 10.
+ * people("age") + 10 // in Scala
+ * }}}
+ *
+ * A more concrete example:
+ * {{{
+ * // To create DataFrame using SQLContext
+ * val people = sqlContext.parquetFile("...")
+ * val department = sqlContext.parquetFile("...")
+ *
+ * people.filter(people("age") > 30)
+ * .join(department, people("deptId") === department("id"))
+ * .groupBy(department("name"), "gender")
+ * .agg(avg(people("salary")), max(people("age")))
+ * }}}
+ */
+// TODO: Improve documentation.
+class DataFrame protected[sql](
+ val sqlContext: SQLContext,
+ private val baseLogicalPlan: LogicalPlan,
+ operatorsEnabled: Boolean)
+ extends DataFrameSpecificApi with RDDApi[Row] {
+
+ protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) =
+ this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined)
+
+ protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true)
+
+ @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)
+
+ @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match {
+ // For various commands (like DDL) and queries with side effects, we force query optimization to
+ // happen right away to let these side effects take place eagerly.
+ case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] | _: WriteToFile =>
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ case _ =>
+ baseLogicalPlan
+ }
+
+ /**
+ * An implicit conversion function internal to this class for us to avoid doing
+ * "new DataFrame(...)" everywhere.
+ */
+ private[this] implicit def toDataFrame(logicalPlan: LogicalPlan): DataFrame = {
+ new DataFrame(sqlContext, logicalPlan, true)
+ }
+
+ /** Return the list of numeric columns, useful for doing aggregation. */
+ protected[sql] def numericColumns: Seq[Expression] = {
+ schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
+ logicalPlan.resolve(n.name, sqlContext.analyzer.resolver).get
+ }
+ }
+
+ /** Resolve a column name into a Catalyst [[NamedExpression]]. */
+ protected[sql] def resolve(colName: String): NamedExpression = {
+ logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
+ throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+ }
+
+ /** Left here for compatibility reasons. */
+ @deprecated("use toDataFrame", "1.3.0")
+ def toSchemaRDD: DataFrame = this
+
+ /**
+ * Return the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
+ */
+ def toDF: DataFrame = this
+
+ /** Return the schema of this [[DataFrame]]. */
+ override def schema: StructType = queryExecution.analyzed.schema
+
+ /** Return all column names and their data types as an array. */
+ override def dtypes: Array[(String, String)] = schema.fields.map { field =>
+ (field.name, field.dataType.toString)
+ }
+
+ /** Return all column names as an array. */
+ override def columns: Array[String] = schema.fields.map(_.name)
+
+ /** Print the schema to the console in a nice tree format. */
+ override def printSchema(): Unit = println(schema.treeString)
+
+ /**
+ * Cartesian join with another [[DataFrame]].
+ *
+ * Note that cartesian joins are very expensive without an extra filter that can be pushed down.
+ *
+ * @param right Right side of the join operation.
+ */
+ override def join(right: DataFrame): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+ }
+
+ /**
+ * Inner join with another [[DataFrame]], using the given join expression.
+ *
+ * {{{
+ * // The following two are equivalent:
+ * df1.join(df2, $"df1Key" === $"df2Key")
+ * df1.join(df2).where($"df1Key" === $"df2Key")
+ * }}}
+ */
+ override def join(right: DataFrame, joinExprs: Column): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr))
+ }
+
+ /**
+ * Join with another [[DataFrame]], using the given join expression. The following performs
+ * a full outer join between `df1` and `df2`.
+ *
+ * {{{
+ * df1.join(df2, "outer", $"df1Key" === $"df2Key")
+ * }}}
+ *
+ * @param right Right side of the join.
+ * @param joinExprs Join expression.
+ * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `left_semi`.
+ */
+ override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
+ }
+
+ /**
+ * Return a new [[DataFrame]] sorted by the specified column, in ascending order.
+ * {{{
+ * // The following 3 are equivalent
+ * df.sort("sortcol")
+ * df.sort($"sortcol")
+ * df.sort($"sortcol".asc)
+ * }}}
+ */
+ override def sort(colName: String): DataFrame = {
+ Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan)
+ }
+
+ /**
+ * Return a new [[DataFrame]] sorted by the given expressions. For example:
+ * {{{
+ * df.sort($"col1", $"col2".desc)
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
+ val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
+ col.expr match {
+ case expr: SortOrder =>
+ expr
+ case expr: Expression =>
+ SortOrder(expr, Ascending)
+ }
+ }
+ Sort(sortOrder, global = true, logicalPlan)
+ }
+
+ /**
+ * Return a new [[DataFrame]] sorted by the given expressions.
+ * This is an alias of the `sort` function.
+ */
+ @scala.annotation.varargs
+ override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
+ sort(sortExpr, sortExprs :_*)
+ }
+
+ /**
+ * Selecting a single column and returning it as a [[Column]].
+ */
+ override def apply(colName: String): Column = {
+ val expr = resolve(colName)
+ new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr)
+ }
+
+ /**
+ * Selecting a set of expressions, wrapped in a Product.
+ * {{{
+ * // The following two are equivalent:
+ * df.apply(($"colA", $"colB" + 1))
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ */
+ override def apply(projection: Product): DataFrame = {
+ require(projection.productArity >= 1)
+ select(projection.productIterator.map {
+ case c: Column => c
+ case o: Any => new Column(Some(sqlContext), None, LiteralExpr(o))
+ }.toSeq :_*)
+ }
+
+ /**
+ * Alias the current [[DataFrame]].
+ */
+ override def as(name: String): DataFrame = Subquery(name, logicalPlan)
+
+ /**
+ * Selecting a set of expressions.
+ * {{{
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def select(cols: Column*): DataFrame = {
+ val exprs = cols.zipWithIndex.map {
+ case (Column(expr: NamedExpression), _) =>
+ expr
+ case (Column(expr: Expression), _) =>
+ Alias(expr, expr.toString)()
+ }
+ Project(exprs.toSeq, logicalPlan)
+ }
+
+ /**
+ * Selecting a set of columns. This is a variant of `select` that can only select
+ * existing columns using column names (i.e. cannot construct expressions).
+ *
+ * {{{
+ * // The following two are equivalent:
+ * df.select("colA", "colB")
+ * df.select($"colA", $"colB")
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def select(col: String, cols: String*): DataFrame = {
+ select((col +: cols).map(new Column(_)) :_*)
+ }
+
+ /**
+ * Filtering rows using the given condition.
+ * {{{
+ * // The following are equivalent:
+ * peopleDf.filter($"age" > 15)
+ * peopleDf.where($"age" > 15)
+ * peopleDf($"age" > 15)
+ * }}}
+ */
+ override def filter(condition: Column): DataFrame = {
+ Filter(condition.expr, logicalPlan)
+ }
+
+ /**
+ * Filtering rows using the given condition. This is an alias for `filter`.
+ * {{{
+ * // The following are equivalent:
+ * peopleDf.filter($"age" > 15)
+ * peopleDf.where($"age" > 15)
+ * peopleDf($"age" > 15)
+ * }}}
+ */
+ override def where(condition: Column): DataFrame = filter(condition)
+
+ /**
+ * Filtering rows using the given condition. This is a shorthand meant for Scala.
+ * {{{
+ * // The following are equivalent:
+ * peopleDf.filter($"age" > 15)
+ * peopleDf.where($"age" > 15)
+ * peopleDf($"age" > 15)
+ * }}}
+ */
+ override def apply(condition: Column): DataFrame = filter(condition)
+
+ /**
+ * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+ * See [[GroupedDataFrame]] for all the available aggregate functions.
+ *
+ * {{{
+ * // Compute the average for all numeric columns grouped by department.
+ * df.groupBy($"department").avg()
+ *
+ * // Compute the max age and average salary, grouped by department and gender.
+ * df.groupBy($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def groupBy(cols: Column*): GroupedDataFrame = {
+ new GroupedDataFrame(this, cols.map(_.expr))
+ }
+
+ /**
+ * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+ * See [[GroupedDataFrame]] for all the available aggregate functions.
+ *
+ * This is a variant of groupBy that can only group by existing columns using column names
+ * (i.e. cannot construct expressions).
+ *
+ * {{{
+ * // Compute the average for all numeric columns grouped by department.
+ * df.groupBy("department").avg()
+ *
+ * // Compute the max age and average salary, grouped by department and gender.
+ * df.groupBy("department", "gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
+ val colNames: Seq[String] = col1 +: cols
+ new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
+ }
+
+ /**
+ * Aggregate on the entire [[DataFrame]] without groups.
+ * {{{
+ * // df.agg(...) is a shorthand for df.groupBy().agg(...)
+ * df.agg(Map("age" -> "max", "salary" -> "avg"))
+ * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+ * }}}
+ */
+ override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
+
+ /**
+ * Aggregate on the entire [[DataFrame]] without groups.
+ * {{{
+ * // df.agg(...) is a shorthand for df.groupBy().agg(...)
+ * df.agg(max($"age"), avg($"salary"))
+ * df.groupBy().agg(max($"age"), avg($"salary"))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
+
+ /**
+ * Return a new [[DataFrame]] by taking the first `n` rows. The difference between this function
+ * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]].
+ */
+ override def limit(n: Int): DataFrame = Limit(LiteralExpr(n), logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] containing union of rows in this frame and another frame.
+ * This is equivalent to `UNION ALL` in SQL.
+ */
+ override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] containing rows only in both this frame and another frame.
+ * This is equivalent to `INTERSECT` in SQL.
+ */
+ override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] containing rows in this frame but not in another frame.
+ * This is equivalent to `EXCEPT` in SQL.
+ */
+ override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] by sampling a fraction of rows.
+ *
+ * @param withReplacement Sample with replacement or not.
+ * @param fraction Fraction of rows to generate.
+ * @param seed Seed for sampling.
+ */
+ override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
+ Sample(fraction, withReplacement, seed, logicalPlan)
+ }
+
+ /**
+ * Return a new [[DataFrame]] by sampling a fraction of rows, using a random seed.
+ *
+ * @param withReplacement Sample with replacement or not.
+ * @param fraction Fraction of rows to generate.
+ */
+ override def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+ sample(withReplacement, fraction, Utils.random.nextLong)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Return a new [[DataFrame]] by adding a column.
+ */
+ override def addColumn(colName: String, col: Column): DataFrame = {
+ select(Column("*"), col.as(colName))
+ }
+
+ /**
+ * Return the first `n` rows.
+ */
+ override def head(n: Int): Array[Row] = limit(n).collect()
+
+ /**
+ * Return the first row.
+ */
+ override def head(): Row = head(1).head
+
+ /**
+ * Return the first row. Alias for head().
+ */
+ override def first(): Row = head()
+
+ override def map[R: ClassTag](f: Row => R): RDD[R] = {
+ rdd.map(f)
+ }
+
+ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
+ rdd.mapPartitions(f)
+ }
+
+ /**
+ * Return the first `n` rows in the [[DataFrame]].
+ */
+ override def take(n: Int): Array[Row] = head(n)
+
+ /**
+ * Return an array that contains all of [[Row]]s in this [[DataFrame]].
+ */
+ override def collect(): Array[Row] = rdd.collect()
+
+ /**
+ * Return a Java list that contains all of [[Row]]s in this [[DataFrame]].
+ */
+ override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
+
+ /**
+ * Return the number of rows in the [[DataFrame]].
+ */
+ override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
+
+ /**
+ * Return a new [[DataFrame]] that has exactly `numPartitions` partitions.
+ */
+ override def repartition(numPartitions: Int): DataFrame = {
+ sqlContext.applySchema(rdd.repartition(numPartitions), schema)
+ }
+
+ override def persist(): this.type = {
+ sqlContext.cacheQuery(this)
+ this
+ }
+
+ override def persist(newLevel: StorageLevel): this.type = {
+ sqlContext.cacheQuery(this, None, newLevel)
+ this
+ }
+
+ override def unpersist(blocking: Boolean): this.type = {
+ sqlContext.tryUncacheQuery(this, blocking)
+ this
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // I/O
+ /////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Return the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
+ */
+ override def rdd: RDD[Row] = {
+ val schema = this.schema
+ queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
+ }
+
+ /**
+ * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this temporary
+ * table is tied to the [[SQLContext]] that was used to create this DataFrame.
+ *
+ * @group schema
+ */
+ override def registerTempTable(tableName: String): Unit = {
+ sqlContext.registerRDDAsTable(this, tableName)
+ }
+
+ /**
+ * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema.
+ * Files that are written out using this method can be read back in as a [[DataFrame]]
+ * using the `parquetFile` function in [[SQLContext]].
+ */
+ override def saveAsParquetFile(path: String): Unit = {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table from the contents of this DataFrame. This will fail if the table already
+ * exists.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ override def saveAsTable(tableName: String): Unit = {
+ sqlContext.executePlan(
+ CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd
+ }
+
+ /**
+ * :: Experimental ::
+ * Adds the rows from this [[DataFrame]] to the specified table, optionally overwriting the existing data.
+ */
+ @Experimental
+ override def insertInto(tableName: String, overwrite: Boolean): Unit = {
+ sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
+ Map.empty, logicalPlan, overwrite)).toRdd
+ }
+
+ /**
+ * Return the content of the [[DataFrame]] as an RDD of JSON strings.
+ */
+ override def toJSON: RDD[String] = {
+ val rowSchema = this.schema
+ this.mapPartitions { iter =>
+ val jsonFactory = new JsonFactory()
+ iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
+ }
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+ // for Python API
+ ////////////////////////////////////////////////////////////////////////////
+ /**
+ * A helper function for Py4j that converts a list of Columns to an array.
+ */
+ protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = {
+ cols.toList.toArray
+ }
+
+ /**
+ * Converts a JavaRDD to a PythonRDD.
+ */
+ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ val fieldTypes = schema.fields.map(_.dataType)
+ val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+ SerDeUtil.javaToPython(jrdd)
+ }
+}
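A brief end-to-end sketch of a few of the methods above; the JSON path, column names, and table name are placeholders:

  import org.apache.spark.sql.Row

  val df = sqlContext.jsonFile("...")                          // any DataFrame
  df.printSchema()

  val top: DataFrame   = df.sort(df("score").desc).limit(10)   // lazy, still a DataFrame
  val rows: Array[Row] = df.head(10)                           // eager, materializes rows

  val flagged = df.addColumn("isHigh", df("score") > 100)
  flagged.registerTempTable("scores")
  sqlContext.sql("SELECT * FROM scores WHERE isHigh").count()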
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
new file mode 100644
index 0000000000000..1f1e9bd9899f6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+
+
+/**
+ * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
+ */
+class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
+ extends GroupedDataFrameApi {
+
+ private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
+ val namedGroupingExprs = groupingExprs.map {
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.toString)()
+ }
+ new DataFrame(df.sqlContext,
+ Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
+ }
+
+ private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
+ df.numericColumns.map { c =>
+ val a = f(c)
+ Alias(a, a.toString)()
+ }
+ }
+
+ private[this] def strToExpr(expr: String): (Expression => Expression) = {
+ expr.toLowerCase match {
+ case "avg" | "average" | "mean" => Average
+ case "max" => Max
+ case "min" => Min
+ case "sum" => Sum
+ case "count" | "size" => Count
+ }
+ }
+
+ /**
+ * Compute aggregates by specifying a map from column name to aggregate methods.
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(Map(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * ))
+ * }}}
+ */
+ override def agg(exprs: Map[String, String]): DataFrame = {
+ exprs.map { case (colName, expr) =>
+ val a = strToExpr(expr)(df(colName).expr)
+ Alias(a, a.toString)()
+ }.toSeq
+ }
+
+ /**
+ * Compute aggregates by specifying a map from column name to aggregate methods.
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(Map(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * ))
+ * }}}
+ */
+ def agg(exprs: java.util.Map[String, String]): DataFrame = {
+ agg(exprs.toMap)
+ }
+
+ /**
+ * Compute aggregates by specifying a series of aggregate columns.
+ * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]].
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * import org.apache.spark.sql.dsl._
+ * df.groupBy("department").agg(max($"age"), sum($"expense"))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def agg(expr: Column, exprs: Column*): DataFrame = {
+ val aggExprs = (expr +: exprs).map(_.expr).map {
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.toString)()
+ }
+
+ new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+ }
+
+ /** Count the number of rows for each group. */
+ override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
+
+ /**
+ * Compute the average value of each numeric column for each group. This is an alias for `avg`.
+ */
+ override def mean(): DataFrame = aggregateNumericColumns(Average)
+
+ /**
+ * Compute the max value of each numeric column for each group.
+ */
+ override def max(): DataFrame = aggregateNumericColumns(Max)
+
+ /**
+ * Compute the mean value of each numeric column for each group.
+ */
+ override def avg(): DataFrame = aggregateNumericColumns(Average)
+
+ /**
+ * Compute the min value for each numeric column for each group.
+ */
+ override def min(): DataFrame = aggregateNumericColumns(Min)
+
+ /**
+ * Compute the sum of each numeric column for each group.
+ */
+ override def sum(): DataFrame = aggregateNumericColumns(Sum)
+}
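A sketch of the grouped aggregations above, assuming a hypothetical df with a department column plus numeric salary and age columns:

  val byDept = df.groupBy("department")

  byDept.count()                                       // Count(1) aliased as "count"
  byDept.agg(Map("salary" -> "avg", "age" -> "max"))   // column name -> aggregate name
  byDept.mean()                                        // Average over every numeric column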
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala
new file mode 100644
index 0000000000000..08cd4d0f3f009
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.types._
+
+object Literal {
+
+ /** Return a new boolean literal. */
+ def apply(literal: Boolean): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new byte literal. */
+ def apply(literal: Byte): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new short literal. */
+ def apply(literal: Short): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new int literal. */
+ def apply(literal: Int): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new long literal. */
+ def apply(literal: Long): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new float literal. */
+ def apply(literal: Float): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new double literal. */
+ def apply(literal: Double): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new string literal. */
+ def apply(literal: String): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new decimal literal. */
+ def apply(literal: BigDecimal): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new decimal literal. */
+ def apply(literal: java.math.BigDecimal): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new timestamp literal. */
+ def apply(literal: java.sql.Timestamp): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new date literal. */
+ def apply(literal: java.sql.Date): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new binary (byte array) literal. */
+ def apply(literal: Array[Byte]): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new null literal. */
+ def apply(literal: Null): Column = new Column(LiteralExpr(null))
+
+ /**
+ * Return a Column expression representing the literal value. Throws an exception if the
+ * data type is not supported by SparkSQL.
+ */
+ protected[sql] def anyToLiteral(literal: Any): Column = {
+ // If the literal is a symbol, convert it into a Column.
+ if (literal.isInstanceOf[Symbol]) {
+ return dsl.symbolToColumn(literal.asInstanceOf[Symbol])
+ }
+
+ val literalExpr = literal match {
+ case v: Int => LiteralExpr(v, IntegerType)
+ case v: Long => LiteralExpr(v, LongType)
+ case v: Double => LiteralExpr(v, DoubleType)
+ case v: Float => LiteralExpr(v, FloatType)
+ case v: Byte => LiteralExpr(v, ByteType)
+ case v: Short => LiteralExpr(v, ShortType)
+ case v: String => LiteralExpr(v, StringType)
+ case v: Boolean => LiteralExpr(v, BooleanType)
+ case v: BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited)
+ case v: java.math.BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited)
+ case v: Decimal => LiteralExpr(v, DecimalType.Unlimited)
+ case v: java.sql.Timestamp => LiteralExpr(v, TimestampType)
+ case v: java.sql.Date => LiteralExpr(v, DateType)
+ case v: Array[Byte] => LiteralExpr(v, BinaryType)
+ case null => LiteralExpr(null, NullType)
+ case _ =>
+ throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal)
+ }
+ new Column(literalExpr)
+ }
+}
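Literal wraps plain Scala values as Columns, and anyToLiteral is the internal fallback the Column operators use for their Any-typed overloads. A sketch; df and colA are placeholders:

  import org.apache.spark.sql.{Column, Literal}

  val one: Column  = Literal(1)                         // IntegerType literal
  val flag: Column = Literal(true).as("flag")

  df.select(df("colA"), flag)
  df.where(df("colA") === "Zaharia")                    // the string goes through anyToLiteral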
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 0a22968cc7807..5030e689c36ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -30,7 +30,6 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.dsl.ExpressionConversions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -43,7 +42,7 @@ import org.apache.spark.util.Utils
/**
* :: AlphaComponent ::
- * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]]
+ * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]]
* objects and the execution of SQL queries.
*
* @groupname userf Spark SQL Functions
@@ -53,7 +52,6 @@ import org.apache.spark.util.Utils
class SQLContext(@transient val sparkContext: SparkContext)
extends org.apache.spark.Logging
with CacheManager
- with ExpressionConversions
with Serializable {
self =>
@@ -111,8 +109,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
- protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
- new this.QueryExecution { val logical = plan }
+
+ protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan)
sparkContext.getConf.getAll.foreach {
case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
@@ -124,24 +122,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = {
+ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val attributeSeq = ScalaReflection.attributesFor[A]
val schema = StructType.fromAttributes(attributeSeq)
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
- new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self))
+ new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
}
/**
- * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]].
+ * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]].
*/
- def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = {
- new SchemaRDD(this, LogicalRelation(baseRelation))
+ def baseRelationToSchemaRDD(baseRelation: BaseRelation): DataFrame = {
+ new DataFrame(this, LogicalRelation(baseRelation))
}
/**
* :: DeveloperApi ::
- * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
* It is important to make sure that the structure of every [[Row]] of the provided RDD matches
* the provided schema. Otherwise, there will be runtime exception.
* Example:
@@ -170,11 +168,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
@DeveloperApi
- def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = {
+ def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
- new SchemaRDD(this, logicalPlan)
+ new DataFrame(this, logicalPlan)
}
/**
@@ -183,7 +181,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = {
+ def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
@@ -201,7 +199,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
) : Row
}
}
- new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this))
+ new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
}
/**
@@ -210,35 +208,35 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = {
+ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
applySchema(rdd.rdd, beanClass)
}
/**
- * Loads a Parquet file, returning the result as a [[SchemaRDD]].
+ * Loads a Parquet file, returning the result as a [[DataFrame]].
*
* @group userf
*/
- def parquetFile(path: String): SchemaRDD =
- new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
+ def parquetFile(path: String): DataFrame =
+ new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
/**
- * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]].
+ * Loads a JSON file (one object per line), returning the result as a [[DataFrame]].
* It goes through the entire dataset once to determine the schema.
*
* @group userf
*/
- def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0)
+ def jsonFile(path: String): DataFrame = jsonFile(path, 1.0)
/**
* :: Experimental ::
* Loads a JSON file (one object per line) and applies the given schema,
- * returning the result as a [[SchemaRDD]].
+ * returning the result as a [[DataFrame]].
*
* @group userf
*/
@Experimental
- def jsonFile(path: String, schema: StructType): SchemaRDD = {
+ def jsonFile(path: String, schema: StructType): DataFrame = {
val json = sparkContext.textFile(path)
jsonRDD(json, schema)
}
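A sketch of the JSON entry points (the path is a placeholder), continuing with the same `sqlContext` and reusing the `schema` from the previous sketch:

    // Infer the schema by scanning the data once:
    val inferred = sqlContext.jsonFile("path/to/people.json")
    // Avoid the extra pass by supplying the schema explicitly:
    val typed = sqlContext.jsonFile("path/to/people.json", schema)
    // The same works for an RDD of JSON strings:
    val fromRdd = sqlContext.jsonRDD(sc.parallelize("""{"name":"Alice","age":30}""" :: Nil))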
@@ -247,29 +245,29 @@ class SQLContext(@transient val sparkContext: SparkContext)
* :: Experimental ::
*/
@Experimental
- def jsonFile(path: String, samplingRatio: Double): SchemaRDD = {
+ def jsonFile(path: String, samplingRatio: Double): DataFrame = {
val json = sparkContext.textFile(path)
jsonRDD(json, samplingRatio)
}
/**
* Loads an RDD[String] storing JSON objects (one object per record), returning the result as a
- * [[SchemaRDD]].
+ * [[DataFrame]].
* It goes through the entire dataset once to determine the schema.
*
* @group userf
*/
- def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0)
+ def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0)
/**
* :: Experimental ::
* Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
- * returning the result as a [[SchemaRDD]].
+ * returning the result as a [[DataFrame]].
*
* @group userf
*/
@Experimental
- def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
+ def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
Option(schema).getOrElse(
@@ -283,7 +281,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* :: Experimental ::
*/
@Experimental
- def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
+ def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
JsonRDD.nullTypeToStringType(
@@ -298,8 +296,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
- catalog.registerTable(Seq(tableName), rdd.queryExecution.logical)
+ def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = {
+ catalog.registerTable(Seq(tableName), rdd.logicalPlan)
}
/**
@@ -321,17 +319,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- def sql(sqlText: String): SchemaRDD = {
+ def sql(sqlText: String): DataFrame = {
if (conf.dialect == "sql") {
- new SchemaRDD(this, parseSql(sqlText))
+ new DataFrame(this, parseSql(sqlText))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}")
}
}
- /** Returns the specified table as a SchemaRDD */
+ /** Returns the specified table as a [[DataFrame]]. */
- def table(tableName: String): SchemaRDD =
- new SchemaRDD(this, catalog.lookupRelation(Seq(tableName)))
+ def table(tableName: String): DataFrame =
+ new DataFrame(this, catalog.lookupRelation(Seq(tableName)))
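A sketch tying these together: register a DataFrame as a temporary table, then query it with `sql` or look it up again with `table` (using the `records` DataFrame from the first sketch):

    records.registerTempTable("records")      // goes through registerRDDAsTable
    val results = sqlContext.sql("SELECT key, value FROM records WHERE key < 10")
    val sameData = sqlContext.table("records")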
/**
* A collection of methods that are considered experimental, but can be used to hook into
@@ -454,15 +452,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
* access to the intermediate phases of query execution for developers.
*/
@DeveloperApi
- protected abstract class QueryExecution {
- def logical: LogicalPlan
+ protected class QueryExecution(val logical: LogicalPlan) {
- lazy val analyzed = ExtractPythonUdfs(analyzer(logical))
- lazy val withCachedData = useCachedData(analyzed)
- lazy val optimizedPlan = optimizer(withCachedData)
+ lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical))
+ lazy val withCachedData: LogicalPlan = useCachedData(analyzed)
+ lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
// TODO: Don't just pick the first one...
- lazy val sparkPlan = {
+ lazy val sparkPlan: SparkPlan = {
SparkPlan.currentContext.set(self)
planner(optimizedPlan).next()
}
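The phases chain together roughly as below. This is illustrative only: `parseSql`, `executePlan`, and `QueryExecution` are protected[sql], so the snippet would only compile from inside the org.apache.spark.sql package:

    val qe = sqlContext.executePlan(sqlContext.parseSql("SELECT key FROM records"))
    qe.analyzed        // resolved LogicalPlan, with Python UDFs extracted
    qe.withCachedData  // analyzed plan with cached relations substituted in
    qe.optimizedPlan   // after the optimizer has run
    qe.sparkPlan       // first physical plan returned by the planner
    qe.toRdd           // executes the plan, exposing the result as an RDD of Rows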
@@ -512,7 +509,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
protected[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
- schemaString: String): SchemaRDD = {
+ schemaString: String): DataFrame = {
val schema = parseDataType(schemaString).asInstanceOf[StructType]
applySchemaToPythonRDD(rdd, schema)
}
@@ -522,7 +519,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
protected[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
- schema: StructType): SchemaRDD = {
+ schema: StructType): DataFrame = {
def needsConversion(dataType: DataType): Boolean = dataType match {
case ByteType => true
@@ -549,7 +546,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
iter.map { m => new GenericRow(m): Row}
}
- new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
+ new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
deleted file mode 100644
index d1e21dffeb8c5..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ /dev/null
@@ -1,511 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import java.util.{List => JList}
-
-import scala.collection.JavaConversions._
-
-import com.fasterxml.jackson.core.JsonFactory
-
-import net.razorvine.pickle.Pickler
-
-import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext}
-import org.apache.spark.annotation.{AlphaComponent, Experimental}
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
-import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{BooleanType, StructType}
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: AlphaComponent ::
- * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions,
- * SchemaRDDs can be used in relational queries, as shown in the examples below.
- *
- * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD
- * whose elements are Scala case classes into a SchemaRDD. This conversion can also be done
- * explicitly using the `createSchemaRDD` function on a [[SQLContext]].
- *
- * A `SchemaRDD` can also be created by loading data in from external sources.
- * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]]
- * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]].
- *
- * == SQL Queries ==
- * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once
- * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements.
- *
- * {{{
- * // One method for defining the schema of an RDD is to make a case class with the desired column
- * // names and types.
- * case class Record(key: Int, value: String)
- *
- * val sc: SparkContext // An existing spark context.
- * val sqlContext = new SQLContext(sc)
- *
- * // Importing the SQL context gives access to all the SQL functions and implicit conversions.
- * import sqlContext._
- *
- * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
- * // Any RDD containing case classes can be registered as a table. The schema of the table is
- * // automatically inferred using scala reflection.
- * rdd.registerTempTable("records")
- *
- * val results: SchemaRDD = sql("SELECT * FROM records")
- * }}}
- *
- * == Language Integrated Queries ==
- *
- * {{{
- *
- * case class Record(key: Int, value: String)
- *
- * val sc: SparkContext // An existing spark context.
- * val sqlContext = new SQLContext(sc)
- *
- * // Importing the SQL context gives access to all the SQL functions and implicit conversions.
- * import sqlContext._
- *
- * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i)))
- *
- * // Example of language integrated queries.
- * rdd.where('key === 1).orderBy('value.asc).select('key).collect()
- * }}}
- *
- * @groupname Query Language Integrated Queries
- * @groupdesc Query Functions that create new queries from SchemaRDDs. The
- * result of all query functions is also a SchemaRDD, allowing multiple operations to be
- * chained using a builder pattern.
- * @groupprio Query -2
- * @groupname schema SchemaRDD Functions
- * @groupprio schema -1
- * @groupname Ungrouped Base RDD Functions
- */
-@AlphaComponent
-class SchemaRDD(
- @transient val sqlContext: SQLContext,
- @transient val baseLogicalPlan: LogicalPlan)
- extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike {
-
- def baseSchemaRDD = this
-
- // =========================================================================================
- // RDD functions: Copy the internal row representation so we present immutable data to users.
- // =========================================================================================
-
- override def compute(split: Partition, context: TaskContext): Iterator[Row] =
- firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema))
-
- override def getPartitions: Array[Partition] = firstParent[Row].partitions
-
- override protected def getDependencies: Seq[Dependency[_]] = {
- schema // Force reification of the schema so it is available on executors.
-
- List(new OneToOneDependency(queryExecution.toRdd))
- }
-
- /**
- * Returns the schema of this SchemaRDD (represented by a [[StructType]]).
- *
- * @group schema
- */
- lazy val schema: StructType = queryExecution.analyzed.schema
-
- /**
- * Returns a new RDD with each row transformed to a JSON string.
- *
- * @group schema
- */
- def toJSON: RDD[String] = {
- val rowSchema = this.schema
- this.mapPartitions { iter =>
- val jsonFactory = new JsonFactory()
- iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
- }
- }
-
-
- // =======================================================================
- // Query DSL
- // =======================================================================
-
- /**
- * Changes the output of this relation to the given expressions, similar to the `SELECT` clause
- * in SQL.
- *
- * {{{
- * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName)
- * }}}
- *
- * @param exprs a set of logical expression that will be evaluated for each input row.
- *
- * @group Query
- */
- def select(exprs: Expression*): SchemaRDD = {
- val aliases = exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"c$i")()
- }
- new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
- }
-
- /**
- * Filters the output, only returning those rows where `condition` evaluates to true.
- *
- * {{{
- * schemaRDD.where('a === 'b)
- * schemaRDD.where('a === 1)
- * schemaRDD.where('a + 'b > 10)
- * }}}
- *
- * @group Query
- */
- def where(condition: Expression): SchemaRDD =
- new SchemaRDD(sqlContext, Filter(condition, logicalPlan))
-
- /**
- * Performs a relational join on two SchemaRDDs
- *
- * @param otherPlan the [[SchemaRDD]] that should be joined with this one.
- * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner`.
- * @param on An optional condition for the join operation. This is equivalent to the `ON`
- * clause in standard SQL. In the case of `Inner` joins, specifying a
- * `condition` is equivalent to adding `where` clauses after the `join`.
- *
- * @group Query
- */
- def join(
- otherPlan: SchemaRDD,
- joinType: JoinType = Inner,
- on: Option[Expression] = None): SchemaRDD =
- new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on))
-
- /**
- * Sorts the results by the given expressions.
- * {{{
- * schemaRDD.orderBy('a)
- * schemaRDD.orderBy('a, 'b)
- * schemaRDD.orderBy('a.asc, 'b.desc)
- * }}}
- *
- * @group Query
- */
- def orderBy(sortExprs: SortOrder*): SchemaRDD =
- new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan))
-
- /**
- * Sorts the results by the given expressions within partition.
- * {{{
- * schemaRDD.sortBy('a)
- * schemaRDD.sortBy('a, 'b)
- * schemaRDD.sortBy('a.asc, 'b.desc)
- * }}}
- *
- * @group Query
- */
- def sortBy(sortExprs: SortOrder*): SchemaRDD =
- new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan))
-
- @deprecated("use limit with integer argument", "1.1.0")
- def limit(limitExpr: Expression): SchemaRDD =
- new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
-
- /**
- * Limits the results by the given integer.
- * {{{
- * schemaRDD.limit(10)
- * }}}
- * @group Query
- */
- def limit(limitNum: Int): SchemaRDD =
- new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan))
-
- /**
- * Performs a grouping followed by an aggregation.
- *
- * {{{
- * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales)
- * }}}
- *
- * @group Query
- */
- def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = {
- val aliasedExprs = aggregateExprs.map {
- case ne: NamedExpression => ne
- case e => Alias(e, e.toString)()
- }
- new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
- }
-
- /**
- * Performs an aggregation over all Rows in this RDD.
- * This is equivalent to a groupBy with no grouping expressions.
- *
- * {{{
- * schemaRDD.aggregate(Sum('sales) as 'totalSales)
- * }}}
- *
- * @group Query
- */
- def aggregate(aggregateExprs: Expression*): SchemaRDD = {
- groupBy()(aggregateExprs: _*)
- }
-
- /**
- * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
- * with the same name, for example, when performing self-joins.
- *
- * {{{
- * val x = schemaRDD.where('a === 1).as('x)
- * val y = schemaRDD.where('a === 2).as('y)
- * x.join(y).where("x.a".attr === "y.a".attr),
- * }}}
- *
- * @group Query
- */
- def as(alias: Symbol) =
- new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan))
-
- /**
- * Combines the tuples of two RDDs with the same schema, keeping duplicates.
- *
- * @group Query
- */
- def unionAll(otherPlan: SchemaRDD) =
- new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan))
-
- /**
- * Performs a relational except on two SchemaRDDs
- *
- * @param otherPlan the [[SchemaRDD]] that should be excepted from this one.
- *
- * @group Query
- */
- def except(otherPlan: SchemaRDD): SchemaRDD =
- new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan))
-
- /**
- * Performs a relational intersect on two SchemaRDDs
- *
- * @param otherPlan the [[SchemaRDD]] that should be intersected with this one.
- *
- * @group Query
- */
- def intersect(otherPlan: SchemaRDD): SchemaRDD =
- new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan))
-
- /**
- * Filters tuples using a function over the value of the specified column.
- *
- * {{{
- * schemaRDD.where('a)((a: Int) => ...)
- * }}}
- *
- * @group Query
- */
- def where[T1](arg1: Symbol)(udf: (T1) => Boolean) =
- new SchemaRDD(
- sqlContext,
- Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan))
-
- /**
- * :: Experimental ::
- * Returns a sampled version of the underlying dataset.
- *
- * @group Query
- */
- @Experimental
- override
- def sample(
- withReplacement: Boolean = true,
- fraction: Double,
- seed: Long) =
- new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
-
- /**
- * :: Experimental ::
- * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this
- * implementation leverages the query optimizer to compute the count on the SchemaRDD, which
- * supports features such as filter pushdown.
- *
- * @group Query
- */
- @Experimental
- override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0)
-
- /**
- * :: Experimental ::
- * Applies the given Generator, or table generating function, to this relation.
- *
- * @param generator A table generating function. The API for such functions is likely to change
- * in future releases
- * @param join when set to true, each output row of the generator is joined with the input row
- * that produced it.
- * @param outer when set to true, at least one row will be produced for each input row, similar to
- * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a
- * given row, a single row will be output, with `NULL` values for each of the
- * generated columns.
- * @param alias an optional alias that can be used as qualifier for the attributes that are
- * produced by this generate operation.
- *
- * @group Query
- */
- @Experimental
- def generate(
- generator: Generator,
- join: Boolean = false,
- outer: Boolean = false,
- alias: Option[String] = None) =
- new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan))
-
- /**
- * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit
- * conversion from a standard RDD to a SchemaRDD.
- *
- * @group schema
- */
- def toSchemaRDD = this
-
- /**
- * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
- */
- private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
- SerDeUtil.javaToPython(jrdd)
- }
-
- /**
- * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
- * format as javaToPython. It is used by pyspark.
- */
- private[sql] def collectToPython: JList[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val pickle = new Pickler
- new java.util.ArrayList(collect().map { row =>
- EvaluatePython.rowToArray(row, fieldTypes)
- }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
- }
-
- /**
- * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same
- * format as javaToPython and collectToPython. It is used by pyspark.
- */
- private[sql] def takeSampleToPython(
- withReplacement: Boolean,
- num: Int,
- seed: Long): JList[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val pickle = new Pickler
- new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row =>
- EvaluatePython.rowToArray(row, fieldTypes)
- }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
- }
-
- /**
- * Creates a SchemaRDD by applying this RDD's schema to a derived RDD. Typically used to wrap
- * the return value of base RDD functions that do not change the schema.
- *
- * @param rdd an RDD derived from this one that has the same schema
- *
- * @group schema
- */
- private def applySchema(rdd: RDD[Row]): SchemaRDD = {
- new SchemaRDD(sqlContext,
- LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext))
- }
-
- // =======================================================================
- // Overridden RDD actions
- // =======================================================================
-
- override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
-
- def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect() : _*)
-
- override def take(num: Int): Array[Row] = limit(num).collect()
-
- // =======================================================================
- // Base RDD functions that do NOT change schema
- // =======================================================================
-
- // Transformations (return a new RDD)
-
- override def coalesce(numPartitions: Int, shuffle: Boolean = false)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.coalesce(numPartitions, shuffle)(ord))
-
- override def distinct(): SchemaRDD = applySchema(super.distinct())
-
- override def distinct(numPartitions: Int)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.distinct(numPartitions)(ord))
-
- def distinct(numPartitions: Int): SchemaRDD =
- applySchema(super.distinct(numPartitions)(null))
-
- override def filter(f: Row => Boolean): SchemaRDD =
- applySchema(super.filter(f))
-
- override def intersection(other: RDD[Row]): SchemaRDD =
- applySchema(super.intersection(other))
-
- override def intersection(other: RDD[Row], partitioner: Partitioner)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.intersection(other, partitioner)(ord))
-
- override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD =
- applySchema(super.intersection(other, numPartitions))
-
- override def repartition(numPartitions: Int)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.repartition(numPartitions)(ord))
-
- override def subtract(other: RDD[Row]): SchemaRDD =
- applySchema(super.subtract(other))
-
- override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD =
- applySchema(super.subtract(other, numPartitions))
-
- override def subtract(other: RDD[Row], p: Partitioner)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.subtract(other, p)(ord))
-
- /** Overridden cache function will always use the in-memory columnar caching. */
- override def cache(): this.type = {
- sqlContext.cacheQuery(this)
- this
- }
-
- override def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheQuery(this, None, newLevel)
- this
- }
-
- override def unpersist(blocking: Boolean): this.type = {
- sqlContext.tryUncacheQuery(this, blocking)
- this
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
deleted file mode 100644
index 3cf9209465b76..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.LogicalRDD
-
-/**
- * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
- */
-private[sql] trait SchemaRDDLike {
- @transient def sqlContext: SQLContext
- @transient val baseLogicalPlan: LogicalPlan
-
- private[sql] def baseSchemaRDD: SchemaRDD
-
- /**
- * :: DeveloperApi ::
- * A lazily computed query execution workflow. All other RDD operations are passed
- * through to the RDD that is produced by this workflow. This workflow is produced lazily because
- * invoking the whole query optimization pipeline can be expensive.
- *
- * The query execution is considered a Developer API as phases may be added or removed in future
- * releases. This execution is only exposed to provide an interface for inspecting the various
- * phases for debugging purposes. Applications should not depend on particular phases existing
- * or producing any specific output, even for exactly the same query.
- *
- * Additionally, the RDD exposed by this execution is not designed for consumption by end users.
- * In particular, it does not contain any schema information, and it reuses Row objects
- * internally. This object reuse improves performance, but can make programming against the RDD
- * more difficult. Instead end users should perform RDD operations on a SchemaRDD directly.
- */
- @transient
- @DeveloperApi
- lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)
-
- @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match {
- // For various commands (like DDL) and queries with side effects, we force query optimization to
- // happen right away to let these side effects take place eagerly.
- case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
- case _ =>
- baseLogicalPlan
- }
-
- override def toString =
- s"""${super.toString}
- |== Query Plan ==
- |${queryExecution.simpleString}""".stripMargin.trim
-
- /**
- * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that
- * are written out using this method can be read back in as a SchemaRDD using the `parquetFile`
- * function.
- *
- * @group schema
- */
- def saveAsParquetFile(path: String): Unit = {
- sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
- }
-
- /**
- * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
- * table is tied to the [[SQLContext]] that was used to create this SchemaRDD.
- *
- * @group schema
- */
- def registerTempTable(tableName: String): Unit = {
- sqlContext.registerRDDAsTable(baseSchemaRDD, tableName)
- }
-
- @deprecated("Use registerTempTable instead of registerAsTable.", "1.1")
- def registerAsTable(tableName: String): Unit = registerTempTable(tableName)
-
- /**
- * :: Experimental ::
- * Adds the rows from this RDD to the specified table, optionally overwriting the existing data.
- *
- * @group schema
- */
- @Experimental
- def insertInto(tableName: String, overwrite: Boolean): Unit =
- sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
- Map.empty, logicalPlan, overwrite)).toRdd
-
- /**
- * :: Experimental ::
- * Appends the rows from this RDD to the specified table.
- *
- * @group schema
- */
- @Experimental
- def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)
-
- /**
- * :: Experimental ::
- * Creates a table from the contents of this SchemaRDD. This will fail if the table already
- * exists.
- *
- * Note that this currently only works with SchemaRDDs that are created from a HiveContext as
- * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
- * an RDD out to a parquet file, and then register that file as a table. This "table" can then
- * be the target of an `insertInto`.
- *
- * @group schema
- */
- @Experimental
- def saveAsTable(tableName: String): Unit =
- sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd
-
- /** Returns the schema as a string in the tree format.
- *
- * @group schema
- */
- def schemaString: String = baseSchemaRDD.schema.treeString
-
- /** Prints out the schema.
- *
- * @group schema
- */
- def printSchema(): Unit = println(schemaString)
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
new file mode 100644
index 0000000000000..073d41e938478
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -0,0 +1,289 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * An internal interface defining the RDD-like methods for [[DataFrame]].
+ * Please use [[DataFrame]] directly, and do NOT use this.
+ */
+trait RDDApi[T] {
+
+ def cache(): this.type = persist()
+
+ def persist(): this.type
+
+ def persist(newLevel: StorageLevel): this.type
+
+ def unpersist(): this.type = unpersist(blocking = false)
+
+ def unpersist(blocking: Boolean): this.type
+
+ def map[R: ClassTag](f: T => R): RDD[R]
+
+ def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
+
+ def take(n: Int): Array[T]
+
+ def collect(): Array[T]
+
+ def collectAsList(): java.util.List[T]
+
+ def count(): Long
+
+ def first(): T
+
+ def repartition(numPartitions: Int): DataFrame
+}
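A sketch of this RDD-like surface as seen from user code, assuming a DataFrame named `df`:

    df.cache()                      // persist with the default storage level
    val rows  = df.take(10)         // Array[Row]
    val total = df.count()
    val keys  = df.map(_.getInt(0)) // map/mapPartitions drop back to a plain RDD
    val wider = df.repartition(4)   // stays a DataFrame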
+
+
+/**
+ * An internal interface defining data frame related methods in [[DataFrame]].
+ * Please use [[DataFrame]] directly, and do NOT use this.
+ */
+trait DataFrameSpecificApi {
+
+ def schema: StructType
+
+ def printSchema(): Unit
+
+ def dtypes: Array[(String, String)]
+
+ def columns: Array[String]
+
+ def head(): Row
+
+ def head(n: Int): Array[Row]
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Relational operators
+ /////////////////////////////////////////////////////////////////////////////
+ def apply(colName: String): Column
+
+ def apply(projection: Product): DataFrame
+
+ @scala.annotation.varargs
+ def select(cols: Column*): DataFrame
+
+ @scala.annotation.varargs
+ def select(col: String, cols: String*): DataFrame
+
+ def apply(condition: Column): DataFrame
+
+ def as(name: String): DataFrame
+
+ def filter(condition: Column): DataFrame
+
+ def where(condition: Column): DataFrame
+
+ @scala.annotation.varargs
+ def groupBy(cols: Column*): GroupedDataFrame
+
+ @scala.annotation.varargs
+ def groupBy(col1: String, cols: String*): GroupedDataFrame
+
+ def agg(exprs: Map[String, String]): DataFrame
+
+ @scala.annotation.varargs
+ def agg(expr: Column, exprs: Column*): DataFrame
+
+ def sort(colName: String): DataFrame
+
+ @scala.annotation.varargs
+ def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
+
+ @scala.annotation.varargs
+ def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+
+ def join(right: DataFrame): DataFrame
+
+ def join(right: DataFrame, joinExprs: Column): DataFrame
+
+ def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame
+
+ def limit(n: Int): DataFrame
+
+ def unionAll(other: DataFrame): DataFrame
+
+ def intersect(other: DataFrame): DataFrame
+
+ def except(other: DataFrame): DataFrame
+
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame
+
+ def sample(withReplacement: Boolean, fraction: Double): DataFrame
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Column mutation
+ /////////////////////////////////////////////////////////////////////////////
+ def addColumn(colName: String, col: Column): DataFrame
+
+ /////////////////////////////////////////////////////////////////////////////
+ // I/O and interaction with other frameworks
+ /////////////////////////////////////////////////////////////////////////////
+
+ def rdd: RDD[Row]
+
+ def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD()
+
+ def toJSON: RDD[String]
+
+ def registerTempTable(tableName: String): Unit
+
+ def saveAsParquetFile(path: String): Unit
+
+ @Experimental
+ def saveAsTable(tableName: String): Unit
+
+ @Experimental
+ def insertInto(tableName: String, overwrite: Boolean): Unit
+
+ @Experimental
+ def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Stat functions
+ /////////////////////////////////////////////////////////////////////////////
+// def describe(): Unit
+//
+// def mean(): Unit
+//
+// def max(): Unit
+//
+// def min(): Unit
+}
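A sketch of how these relational operators compose, assuming two DataFrames `people` (name, age, deptId) and `depts` (id, deptName); the column names are illustrative:

    val adults = people.filter(people("age") >= 21)
    val picked = adults.select(adults("name"), adults("age") + 1)
    val joined = adults.join(depts, adults("deptId") === depts("id"))
    val top    = joined.sort("name").limit(100)
    top.registerTempTable("top_people")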
+
+
+/**
+ * An internal interface defining expression APIs for [[DataFrame]].
+ * Please use [[DataFrame]] and [[Column]] directly, and do NOT use this.
+ */
+trait ExpressionApi {
+
+ def isComputable: Boolean
+
+ def unary_- : Column
+ def unary_! : Column
+ def unary_~ : Column
+
+ def + (other: Column): Column
+ def + (other: Any): Column
+ def - (other: Column): Column
+ def - (other: Any): Column
+ def * (other: Column): Column
+ def * (other: Any): Column
+ def / (other: Column): Column
+ def / (other: Any): Column
+ def % (other: Column): Column
+ def % (other: Any): Column
+ def & (other: Column): Column
+ def & (other: Any): Column
+ def | (other: Column): Column
+ def | (other: Any): Column
+ def ^ (other: Column): Column
+ def ^ (other: Any): Column
+
+ def && (other: Column): Column
+ def && (other: Boolean): Column
+ def || (other: Column): Column
+ def || (other: Boolean): Column
+
+ def < (other: Column): Column
+ def < (other: Any): Column
+ def <= (other: Column): Column
+ def <= (other: Any): Column
+ def > (other: Column): Column
+ def > (other: Any): Column
+ def >= (other: Column): Column
+ def >= (other: Any): Column
+ def === (other: Column): Column
+ def === (other: Any): Column
+ def equalTo(other: Column): Column
+ def equalTo(other: Any): Column
+ def <=> (other: Column): Column
+ def <=> (other: Any): Column
+ def !== (other: Column): Column
+ def !== (other: Any): Column
+
+ @scala.annotation.varargs
+ def in(list: Column*): Column
+
+ def like(other: Column): Column
+ def like(other: String): Column
+ def rlike(other: Column): Column
+ def rlike(other: String): Column
+
+ def contains(other: Column): Column
+ def contains(other: Any): Column
+ def startsWith(other: Column): Column
+ def startsWith(other: String): Column
+ def endsWith(other: Column): Column
+ def endsWith(other: String): Column
+
+ def substr(startPos: Column, len: Column): Column
+ def substr(startPos: Int, len: Int): Column
+
+ def isNull: Column
+ def isNotNull: Column
+
+ def getItem(ordinal: Column): Column
+ def getItem(ordinal: Int): Column
+ def getField(fieldName: String): Column
+
+ def cast(to: DataType): Column
+
+ def asc: Column
+ def desc: Column
+
+ def as(alias: String): Column
+}
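A sketch of composing Column expressions, again assuming the hypothetical `people` DataFrame:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.types.LongType

    val age: Column = people("age")
    people.where(age > 21 && age < 65)
    people.select(people("name").as("who"), (age + 1).cast(LongType))
    people.where(people("name").startsWith("A") || people("name").isNull)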
+
+
+/**
+ * An internal interface defining aggregation APIs for [[DataFrame]].
+ * Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this.
+ */
+trait GroupedDataFrameApi {
+
+ def agg(exprs: Map[String, String]): DataFrame
+
+ @scala.annotation.varargs
+ def agg(expr: Column, exprs: Column*): DataFrame
+
+ def avg(): DataFrame
+
+ def mean(): DataFrame
+
+ def min(): DataFrame
+
+ def max(): DataFrame
+
+ def sum(): DataFrame
+
+ def count(): DataFrame
+
+ // TODO: Add var, std
+}
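A sketch of grouped aggregation over the hypothetical `people` DataFrame (the `salary` column and the aggregate-function names passed in the Map are assumptions for illustration):

    val byDept = people.groupBy(people("deptId"))
    byDept.count()                                      // one row per department
    byDept.avg()                                        // average of every numeric column
    byDept.agg(Map("age" -> "max", "salary" -> "avg"))  // per-column aggregates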
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala
new file mode 100644
index 0000000000000..29c3d26ae56d9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.{Timestamp, Date}
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.DataType
+
+
+package object dsl {
+
+ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+ /** Converts $"col name" into an [[Column]]. */
+ implicit class StringToColumn(val sc: StringContext) extends AnyVal {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args :_*))
+ }
+ }
+
+ private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
+
+ def sum(e: Column): Column = Sum(e.expr)
+ def sumDistinct(e: Column): Column = SumDistinct(e.expr)
+ def count(e: Column): Column = Count(e.expr)
+
+ @scala.annotation.varargs
+ def countDistinct(expr: Column, exprs: Column*): Column =
+ CountDistinct((expr +: exprs).map(_.expr))
+
+ def avg(e: Column): Column = Average(e.expr)
+ def first(e: Column): Column = First(e.expr)
+ def last(e: Column): Column = Last(e.expr)
+ def min(e: Column): Column = Min(e.expr)
+ def max(e: Column): Column = Max(e.expr)
+ def upper(e: Column): Column = Upper(e.expr)
+ def lower(e: Column): Column = Lower(e.expr)
+ def sqrt(e: Column): Column = Sqrt(e.expr)
+ def abs(e: Column): Column = Abs(e.expr)
+
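A sketch of the helpers above combined with the $"..." interpolator, assuming the usual `people` DataFrame:

    import org.apache.spark.sql.dsl._

    people.select(upper($"name"), abs($"age"))
    people.groupBy($"deptId").agg(countDistinct($"name"), avg($"age"))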
+ // scalastyle:off
+
+ object literals {
+
+ implicit def booleanToLiteral(b: Boolean): Column = Literal(b)
+
+ implicit def byteToLiteral(b: Byte): Column = Literal(b)
+
+ implicit def shortToLiteral(s: Short): Column = Literal(s)
+
+ implicit def intToLiteral(i: Int): Column = Literal(i)
+
+ implicit def longToLiteral(l: Long): Column = Literal(l)
+
+ implicit def floatToLiteral(f: Float): Column = Literal(f)
+
+ implicit def doubleToLiteral(d: Double): Column = Literal(d)
+
+ implicit def stringToLiteral(s: String): Column = Literal(s)
+
+ implicit def dateToLiteral(d: Date): Column = Literal(d)
+
+ implicit def bigDecimalToLiteral(d: BigDecimal): Column = Literal(d.underlying())
+
+ implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Column = Literal(d)
+
+ implicit def timestampToLiteral(t: Timestamp): Column = Literal(t)
+
+ implicit def binaryToLiteral(a: Array[Byte]): Column = Literal(a)
+ }
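A sketch of the literal conversions: importing `literals._` lifts plain Scala values into Literal Columns wherever a Column is expected:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.dsl._
    import org.apache.spark.sql.dsl.literals._

    val threshold: Column = 21      // intToLiteral
    val label: Column = "minor"     // stringToLiteral
    people.where($"age" < threshold).select($"name", label)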
+
+
+ /* Use the following code to generate:
+ (0 to 22).map { x =>
+ val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
+ val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
+ val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+ val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+ println(s"""
+ /**
+ * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf))
+ }""")
+ }
+
+ (0 to 22).map { x =>
+ val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+ val fTypes = Seq.fill(x + 1)("_").mkString(", ")
+ val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+ println(s"""
+ /**
+ * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
+ ScalaUdf(f, returnType, Seq($argsInUdf))
+ }""")
+ }
+ }
+ */
+ /**
+ * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag](f: Function0[RT]): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq())
+ }
+
+ /**
+ * Call a Scala function of 1 argument as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr))
+ }
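A sketch of calling a plain Scala function through callUDF: the first form infers the return type from the function's signature, the second uses the returnType-taking overloads defined further below:

    import org.apache.spark.sql.dsl._
    import org.apache.spark.sql.types.IntegerType

    val strLen = (s: String) => s.length
    people.select($"name", callUDF(strLen, $"name"))                               // inferred
    people.select($"name", callUDF((s: String) => s.length, IntegerType, $"name")) // explicit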
+
+ /**
+ * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr))
+ }
+
+ /**
+ * Call a Scala function of 3 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr))
+ }
+
+ /**
+ * Call a Scala function of 4 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
+ }
+
+ /**
+ * Call a Scala function of 5 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
+ }
+
+ /**
+ * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
+ }
+
+ /**
+ * Call a Scala function of 7 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
+ }
+
+ /**
+ * Call a Scala function of 8 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
+ }
+
+ /**
+ * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
+ }
+
+ /**
+ * Call a Scala function of 10 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
+ }
+
+ /**
+ * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr))
+ }
+
+ /**
+ * Call a Scala function of 12 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr))
+ }
+
+ /**
+ * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr))
+ }
+
+ /**
+ * Call a Scala function of 14 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr))
+ }
+
+ /**
+ * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr))
+ }
+
+ /**
+ * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr))
+ }
+
+ /**
+ * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr))
+ }
+
+ /**
+ * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr))
+ }
+
+ /**
+ * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr))
+ }
+
+ /**
+ * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr))
+ }
+
+ /**
+ * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr))
+ }
+
+ /**
+ * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr))
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Call a Scala function of 0 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function0[_], returnType: DataType): Column = {
+ ScalaUdf(f, returnType, Seq())
+ }
+
+ /**
+ * Call a Scala function of 1 argument as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr))
+ }
+
+ /**
+ * Call a Scala function of 2 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr))
+ }
+
+ /**
+ * Call a Scala function of 3 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
+ }
+
+ /**
+ * Call a Scala function of 4 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
+ }
+
+ /**
+ * Call a Scala function of 5 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
+ }
+
+ /**
+ * Call a Scala function of 6 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
+ }
+
+ /**
+ * Call a Scala function of 7 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
+ }
+
+ /**
+ * Call a Scala function of 8 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
+ }
+
+ /**
+ * Call a Scala function of 9 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
+ }
+
+ /**
+ * Call a Scala function of 10 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
+ }
+
+ /**
+ * Call a Scala function of 11 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr))
+ }
+
+ /**
+ * Call a Scala function of 12 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr))
+ }
+
+ /**
+ * Call a Scala function of 13 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr))
+ }
+
+ /**
+ * Call a Scala function of 14 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr))
+ }
+
+ /**
+ * Call a Scala function of 15 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr))
+ }
+
+ /**
+ * Call a Scala function of 16 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr))
+ }
+
+ /**
+ * Call a Scala function of 17 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr))
+ }
+
+ /**
+ * Call a Scala function of 18 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr))
+ }
+
+ /**
+ * Call a Scala function of 19 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr))
+ }
+
+ /**
+ * Call a Scala function of 20 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr))
+ }
+
+ /**
+ * Call a Scala function of 21 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr))
+ }
+
+ /**
+ * Call a Scala function of 22 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr))
+ }
+
+ // scalastyle:on
+}
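For reference, a minimal sketch of how the two callUDF flavours added above differ in use: the first infers the return type from the function's signature via TypeTags, the second takes the return DataType explicitly. The DataFrame `df` and its columns "a" (Int) and "b" (String) are hypothetical, and the dsl import mirrors the one the test suites below rely on.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types.IntegerType

def callUdfSketch(df: DataFrame): DataFrame = {
  // Return type inferred from the function signature (TypeTag-based overloads).
  val inferred = callUDF((i: Int, s: String) => s.length + i, $"a", $"b")
  // Return type supplied explicitly as a DataType (no type inference involved).
  val explicit = callUDF((i: Int, s: String) => s.length + i, IntegerType, $"a", $"b")
  df.select(inferred, explicit)
}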
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 52a31f01a4358..6fba76c52171b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Row, Attribute}
import org.apache.spark.sql.catalyst.plans.logical
@@ -137,7 +137,9 @@ case class CacheTableCommand(
isLazy: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext) = {
- plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName))
+ plan.foreach { logicalPlan =>
+ sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName)
+ }
sqlContext.cacheTable(tableName)
if (!isLazy) {
@@ -159,7 +161,7 @@ case class CacheTableCommand(
case class UncacheTableCommand(tableName: String) extends RunnableCommand {
override def run(sqlContext: SQLContext) = {
- sqlContext.table(tableName).unpersist()
+ sqlContext.table(tableName).unpersist(blocking = false)
Seq.empty[Row]
}
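For context, a hedged sketch of the SQL statements that exercise these commands; sqlContext, the table names, and the source table src are placeholders.

sqlContext.sql("CACHE TABLE t AS SELECT * FROM src")        // plan registered via registerRDDAsTable, then cached eagerly
sqlContext.sql("CACHE LAZY TABLE t2 AS SELECT * FROM src")  // isLazy = true: materialized on first use
sqlContext.sql("UNCACHE TABLE t")                           // unpersists the cached table, now with blocking = false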
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 4d7e338e8ed13..aeb0960e87f14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet
import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext._
-import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.types._
@@ -42,7 +42,7 @@ package object debug {
* Augments SchemaRDDs with debug methods.
*/
@DeveloperApi
- implicit class DebugQuery(query: SchemaRDD) {
+ implicit class DebugQuery(query: DataFrame) {
def debug(): Unit = {
val plan = query.queryExecution.executedPlan
val visited = new collection.mutable.HashSet[TreeNodeRef]()
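With the implicit conversion above in scope, any DataFrame gains a debug() method. A brief hypothetical usage, assuming an in-scope sqlContext and a registered table src:

import org.apache.spark.sql.execution.debug._

val df = sqlContext.sql("SELECT key, count(1) FROM src GROUP BY key")
df.debug()  // instruments the executed plan and prints per-operator tuple counts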
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 6dd39be807037..7c49b5220d607 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -37,5 +37,5 @@ package object sql {
* Converts a logical plan into zero or more SparkPlans.
*/
@DeveloperApi
- type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
+ protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 02ce1b3e6d811..0b312ef51daa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.util
import org.apache.spark.util.Utils
@@ -100,7 +100,7 @@ trait ParquetTest {
*/
protected def withParquetRDD[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
- (f: SchemaRDD => Unit): Unit = {
+ (f: DataFrame => Unit): Unit = {
withParquetFile(data)(path => f(parquetFile(path)))
}
@@ -120,7 +120,7 @@ trait ParquetTest {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withParquetRDD(data) { rdd =>
- rdd.registerTempTable(tableName)
+ sqlContext.registerRDDAsTable(rdd, tableName)
withTempTable(tableName)(f)
}
}
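A hedged sketch of how the updated helper reads from a suite that mixes in ParquetTest; the Record case class and the data are made up.

case class Record(i: Int, s: String)

withParquetRDD((1 to 3).map(i => Record(i, i.toString))) { df =>
  // df is a DataFrame now rather than a SchemaRDD
  assert(df.collect().length == 3)
}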
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 37853d4d03019..d13f2ce2a5e1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -18,19 +18,18 @@
package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
-import org.apache.spark.sql._
+import org.apache.spark.sql.{Row, Strategy}
import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
private[sql] object DataSourceStrategy extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) =>
pruneFilterProjectRaw(
l,
@@ -112,23 +111,26 @@ private[sql] object DataSourceStrategy extends Strategy {
}
}
+ /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */
protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect {
- case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v)
- case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v)
+ case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v)
+ case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v)
- case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v)
- case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v)
+ case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v)
+ case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v)
- case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v)
- case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
+ case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v)
+ case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
- case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
+ case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
GreaterThanOrEqual(a.name, v)
- case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
+ case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
+ case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
+ LessThanOrEqual(a.name, v)
+ case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
+ GreaterThanOrEqual(a.name, v)
case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
}
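Purely illustrative: what the selectFilters translation above does, including the operator flip when the literal sits on the left of the comparison. The attribute name and values are invented, and since DataSourceStrategy and selectFilters are package-private this would only compile from inside the sql package.

import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.types.IntegerType

val age = expressions.AttributeReference("age", IntegerType, nullable = true)()

DataSourceStrategy.selectFilters(Seq(
  expressions.GreaterThan(age, expressions.Literal(21)),  // becomes sources.GreaterThan("age", 21)
  expressions.GreaterThan(expressions.Literal(21), age)   // becomes sources.LessThan("age", 21) (operator flipped)
))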
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 171b816a26332..b4af91a768efb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.sources
import scala.language.implicitConversions
import org.apache.spark.Logging
-import org.apache.spark.sql.{SchemaRDD, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.execution.RunnableCommand
@@ -225,7 +225,8 @@ private [sql] case class CreateTempTableUsing(
def run(sqlContext: SQLContext) = {
val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options)
- new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName)
+ sqlContext.registerRDDAsTable(
+ new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
Seq.empty
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index f9c082216085d..2564c849b87f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.test
import scala.language.implicitConversions
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
/** A SQLContext that can be used for local testing. */
@@ -40,8 +40,8 @@ object TestSQLContext
* Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to
* construct SchemaRDD directly out of local data without relying on implicits.
*/
- protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = {
- new SchemaRDD(this, plan)
+ protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
+ new DataFrame(this, plan)
}
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
index 9ff40471a00af..e5588938ea162 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
@@ -61,7 +61,7 @@ public Integer call(String str) throws Exception {
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test')").first();
+ Row result = sqlContext.sql("SELECT stringLengthTest('test')").head();
assert(result.getInt(0) == 4);
}
@@ -81,7 +81,7 @@ public Integer call(String str1, String str2) throws Exception {
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first();
+ Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head();
assert(result.getInt(0) == 9);
}
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
index 9e96738ac095a..badd00d34b9b1 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -98,8 +98,8 @@ public Row call(Person person) throws Exception {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema);
- schemaRDD.registerTempTable("people");
+ DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema);
+ df.registerTempTable("people");
Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect();
List expected = new ArrayList(2);
@@ -147,17 +147,17 @@ public void applySchemaToJSON() {
null,
"this is another simple string."));
- SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd());
- StructType actualSchema1 = schemaRDD1.schema();
+ DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd());
+ StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
- schemaRDD1.registerTempTable("jsonTable1");
+ df1.registerTempTable("jsonTable1");
List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema);
- StructType actualSchema2 = schemaRDD2.schema();
+ DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema);
+ StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
- schemaRDD2.registerTempTable("jsonTable2");
+ df2.registerTempTable("jsonTable2");
List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList();
Assert.assertEquals(expectedResult, actual2);
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index cfc037caff2a9..34763156a6d11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index afbfe214f1ce4..a5848f219cea9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types._
/* Implicits */
-import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import scala.language.postfixOps
@@ -44,46 +42,46 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
- testData2.groupBy('a)('a, sum('b)),
+ testData2.groupBy("a").agg($"a", sum($"b")),
Seq(Row(1,3), Row(2,3), Row(3,3))
)
checkAnswer(
- testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
+ testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
- testData2.aggregate(sum('b)),
+ testData2.agg(sum('b)),
Row(9)
)
}
test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
- testData.where($"key" === 1).select($"value"),
+ testData.where($"key" === Literal(1)).select($"value"),
Row("1"))
}
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
- testData.where('key === 1).select('value),
+ testData.where('key === Literal(1)).select('value),
Row("1"))
}
test("select *") {
checkAnswer(
- testData.select(Star(None)),
+ testData.select($"*"),
testData.collect().toSeq)
}
test("simple select") {
checkAnswer(
- testData.where('key === 1).select('value),
+ testData.where('key === Literal(1)).select('value),
Row("1"))
}
test("select with functions") {
checkAnswer(
- testData.select(sum('value), avg('value), count(1)),
+ testData.select(sum('value), avg('value), count(Literal(1))),
Row(5050.0, 50.5, 100))
checkAnswer(
@@ -120,46 +118,19 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
arrayData.orderBy('data.getItem(0).asc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(0).desc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(1).asc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(1).desc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
- }
-
- test("partition wide sorting") {
- // 2 partitions totally, and
- // Partition #1 with values:
- // (1, 1)
- // (1, 2)
- // (2, 1)
- // Partition #2 with values:
- // (2, 2)
- // (3, 1)
- // (3, 2)
- checkAnswer(
- testData2.sortBy('a.asc, 'b.asc),
- Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
-
- checkAnswer(
- testData2.sortBy('a.asc, 'b.desc),
- Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1)))
-
- checkAnswer(
- testData2.sortBy('a.desc, 'b.desc),
- Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2)))
-
- checkAnswer(
- testData2.sortBy('a.desc, 'b.asc),
- Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2)))
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("limit") {
@@ -176,71 +147,51 @@ class DslQuerySuite extends QueryTest {
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
- test("SPARK-3395 limit distinct") {
- val filtered = TestData.testData2
- .distinct()
- .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending))
- .limit(1)
- .registerTempTable("onerow")
- checkAnswer(
- sql("select * from onerow inner join testData2 on onerow.a = testData2.a"),
- Row(1, 1, 1, 1) ::
- Row(1, 1, 1, 2) :: Nil)
- }
-
- test("SPARK-3858 generator qualifiers are discarded") {
- checkAnswer(
- arrayData.as('ad)
- .generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
- .select("ex.data".attr),
- Seq(1, 2, 3, 2, 3, 4).map(Row(_)))
- }
-
test("average") {
checkAnswer(
- testData2.aggregate(avg('a)),
+ testData2.agg(avg('a)),
Row(2.0))
checkAnswer(
- testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
+ testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)
checkAnswer(
- decimalData.aggregate(avg('a)),
+ decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
- decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
+ decimalData.agg(avg('a), sumDistinct('a)), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
- decimalData.aggregate(avg('a cast DecimalType(10, 2))),
+ decimalData.agg(avg('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
- decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
+ decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
- testData3.aggregate(avg('b)),
+ testData3.agg(avg('b)),
Row(2.0))
checkAnswer(
- testData3.aggregate(avg('b), countDistinct('b)),
+ testData3.agg(avg('b), countDistinct('b)),
Row(2.0, 1))
checkAnswer(
- testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
+ testData3.agg(avg('b), sumDistinct('b)), // non-partial
Row(2.0, 2.0))
}
test("zero average") {
checkAnswer(
- emptyTableData.aggregate(avg('a)),
+ emptyTableData.agg(avg('a)),
Row(null))
checkAnswer(
- emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
+ emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
Row(null, null))
}
@@ -248,28 +199,28 @@ class DslQuerySuite extends QueryTest {
assert(testData2.count() === testData2.map(_ => 1).count())
checkAnswer(
- testData2.aggregate(count('a), sumDistinct('a)), // non-partial
+ testData2.agg(count('a), sumDistinct('a)), // non-partial
Row(6, 6.0))
}
test("null count") {
checkAnswer(
- testData3.groupBy('a)('a, count('b)),
+ testData3.groupBy('a).agg('a, count('b)),
Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
- testData3.groupBy('a)('a, count('a + 'b)),
+ testData3.groupBy('a).agg('a, count('a + 'b)),
Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
- testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
+ testData3.agg(count('a), count('b), count(Literal(1)), countDistinct('a), countDistinct('b)),
Row(2, 1, 2, 2, 1)
)
checkAnswer(
- testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+ testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
Row(1, 1, 2)
)
}
@@ -278,19 +229,19 @@ class DslQuerySuite extends QueryTest {
assert(emptyTableData.count() === 0)
checkAnswer(
- emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
+ emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
Row(0, null))
}
test("zero sum") {
checkAnswer(
- emptyTableData.aggregate(sum('a)),
+ emptyTableData.agg(sum('a)),
Row(null))
}
test("zero sum distinct") {
checkAnswer(
- emptyTableData.aggregate(sumDistinct('a)),
+ emptyTableData.agg(sumDistinct('a)),
Row(null))
}
@@ -320,7 +271,7 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
// SELECT *, foo(key, value) FROM testData
- testData.select(Star(None), foo.call('key, 'value)).limit(3),
+ testData.select($"*", callUDF(foo, 'key, 'value)).limit(3),
Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
@@ -362,7 +313,7 @@ class DslQuerySuite extends QueryTest {
test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
- ('a' to 'd').map(c => Row(c.toString.toUpperCase()))
+ ('a' to 'd').map(c => Row(c.toString.toUpperCase))
)
checkAnswer(
@@ -379,7 +330,7 @@ class DslQuerySuite extends QueryTest {
test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
- ('A' to 'F').map(c => Row(c.toString.toLowerCase()))
+ ('A' to 'F').map(c => Row(c.toString.toLowerCase))
)
checkAnswer(
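A hedged summary of the DSL forms this suite now exercises, collected in one place; the DataFrame df and its Int columns "a" and "b" are hypothetical, and the imports mirror the suite's own.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test.TestSQLContext._

def dslSketch(df: DataFrame): Unit = {
  df.groupBy("a").agg($"a", sum($"b"))      // grouped aggregation, formerly groupBy('a)('a, sum('b))
  df.agg(avg('a), countDistinct('a))        // global aggregation, formerly aggregate(...)
  df.where('a === Literal(1)).select($"*")  // literals wrapped explicitly; $"*" replaces Star(None)
}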
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index cd36da7751e83..79713725c0d77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,19 +20,20 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext._
+
class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData
test("equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
+ val x = testData2.as("x")
+ val y = testData2.as("y")
+ val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
@@ -105,17 +106,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("multiple-key equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner,
- Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
+ val x = testData2.as("x")
+ val y = testData2.as("y")
+ val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
test("inner join where, one match per row") {
checkAnswer(
- upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
+ upperCaseData.join(lowerCaseData).where('n === 'N),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("inner join ON, one match per row") {
checkAnswer(
- upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N"),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -136,10 +136,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("inner join, where, multiple matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 1).as('y)
+ val x = testData2.where($"a" === Literal(1)).as("x")
+ val y = testData2.where($"a" === Literal(1)).as("y")
checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
+ x.join(y).where($"x.a" === $"y.a"),
Row(1,1,1,1) ::
Row(1,1,1,2) ::
Row(1,2,1,1) ::
@@ -148,22 +148,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("inner join, no matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 2).as('y)
+ val x = testData2.where($"a" === Literal(1)).as("x")
+ val y = testData2.where($"a" === Literal(2)).as("y")
checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
+ x.join(y).where($"x.a" === $"y.a"),
Nil)
}
test("big inner join, 4 matches per row") {
val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
- val bigDataX = bigData.as('x)
- val bigDataY = bigData.as('y)
+ val bigDataX = bigData.as("x")
+ val bigDataY = bigData.as("y")
checkAnswer(
- bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
- testData.flatMap(
- row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
+ bigDataX.join(bigDataY).where($"x.key" === $"y.key"),
+ testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
test("cartisian product join") {
@@ -177,7 +176,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("left outer join") {
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N", "left"),
Row(1, "A", 1, "a") ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -186,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > Literal(1), "left"),
Row(1, "A", null, null) ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -195,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > Literal(1), "left"),
Row(1, "A", null, null) ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -204,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"),
Row(1, "A", 1, "a") ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -240,7 +239,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("right outer join") {
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N", "right"),
Row(1, "a", 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -248,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > Literal(1), "right"),
Row(null, null, 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > Literal(1), "right"),
Row(null, null, 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -264,7 +263,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"),
Row(1, "a", 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -299,14 +298,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("full outer join") {
- upperCaseData.where('N <= 4).registerTempTable("left")
- upperCaseData.where('N >= 3).registerTempTable("right")
+ upperCaseData.where('N <= Literal(4)).registerTempTable("left")
+ upperCaseData.where('N >= Literal(3)).registerTempTable("right")
val left = UnresolvedRelation(Seq("left"), None)
val right = UnresolvedRelation(Seq("right"), None)
checkAnswer(
- left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
+ left.join(right, $"left.N" === $"right.N", "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", 3, "C") ::
@@ -315,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
+ left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== Literal(3)), "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", null, null) ::
@@ -325,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
+ left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== Literal(3)), "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", null, null) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 42a21c148df53..07c52de377a60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -26,12 +26,12 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer either contains all of the keywords or
* contains none of them.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param exists true to make sure the keywords are listed in the output; otherwise,
* make sure none of the keywords appear in the output
* @param keywords the keywords to check for in the output
*/
- def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) {
+ def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
val outputs = rdd.collect().map(_.mkString).mkString
for (key <- keywords) {
if (exists) {
@@ -44,10 +44,10 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param expectedAnswer the expected result; can be an Any, a Seq[Product], or a Seq[Seq[Any]].
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -91,7 +91,7 @@ class QueryTest extends PlanTest {
}
}
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(rdd, Seq(expectedAnswer))
}
@@ -102,7 +102,7 @@ class QueryTest extends PlanTest {
}
/** Asserts that a given SchemaRDD will be executed using the given number of cached results. */
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
val planWithCaching = query.queryExecution.withCachedData
val cachedData = planWithCaching collect {
case cached: InMemoryRelation => cached
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 03b44ca1d6695..4fff99cb3f3e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,6 +21,7 @@ import java.util.TimeZone
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
@@ -29,6 +30,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._
+
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData
@@ -381,8 +383,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("big inner join, 4 matches per row") {
-
-
checkAnswer(
sql(
"""
@@ -396,7 +396,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| SELECT * FROM testData UNION ALL
| SELECT * FROM testData) y
|WHERE x.key = y.key""".stripMargin),
- testData.flatMap(
+ testData.rdd.flatMap(
row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
@@ -742,7 +742,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("metadata is propagated correctly") {
- val person = sql("SELECT * FROM person")
+ val person: DataFrame = sql("SELECT * FROM person")
val schema = person.schema
val docKey = "doc"
val docValue = "first name"
@@ -751,14 +751,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person, schemaWithMeta)
- def validateMetadata(rdd: SchemaRDD): Unit = {
+ val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
personWithMeta.registerTempTable("personWithMeta")
- validateMetadata(personWithMeta.select('name))
- validateMetadata(personWithMeta.select("name".attr))
- validateMetadata(personWithMeta.select('id, 'name))
+ validateMetadata(personWithMeta.select($"name"))
+ validateMetadata(personWithMeta.select($"name"))
+ validateMetadata(personWithMeta.select($"id", $"name"))
validateMetadata(sql("SELECT * FROM personWithMeta"))
validateMetadata(sql("SELECT id, name FROM personWithMeta"))
validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 808ed5288cfb8..fffa2b7dfa6e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test._
/* Implicits */
@@ -29,11 +30,11 @@ case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
+ (1 to 100).map(i => TestData(i, i.toString))).toDF
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
+ (1 to 100).map(i => TestData(-i, (-i).toString))).toDF
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
@@ -44,7 +45,7 @@ object TestData {
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD
+ LargeAndSmallInts(3, 2) :: Nil).toDF
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
@@ -55,7 +56,7 @@ object TestData {
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
- TestData2(3, 2) :: Nil, 2).toSchemaRDD
+ TestData2(3, 2) :: Nil, 2).toDF
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
@@ -67,7 +68,7 @@ object TestData {
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
- DecimalData(3, 2) :: Nil).toSchemaRDD
+ DecimalData(3, 2) :: Nil).toDF
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
@@ -77,17 +78,17 @@ object TestData {
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD
+ BinaryData("123".getBytes(), 4) :: Nil).toDF
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
- TestData3(2, Some(2)) :: Nil).toSchemaRDD
+ TestData3(2, Some(2)) :: Nil).toDF
testData3.registerTempTable("testData3")
- val emptyTableData = logical.LocalRelation('a.int, 'b.int)
+ val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
case class UpperCaseData(N: Int, L: String)
val upperCaseData =
@@ -97,7 +98,7 @@ object TestData {
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil).toSchemaRDD
+ UpperCaseData(6, "F") :: Nil).toDF
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
@@ -106,7 +107,7 @@ object TestData {
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil).toSchemaRDD
+ LowerCaseData(4, "d") :: Nil).toDF
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
@@ -200,6 +201,6 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
:: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
- :: Nil).toSchemaRDD
+ :: Nil).toDF
complexData.registerTempTable("complexData")
}
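The new pattern in this file, restated with a made-up case class and table name; it assumes the same implicit imports TestData.scala itself uses.

case class Person(id: Int, name: String)

val people = TestSQLContext.sparkContext
  .parallelize(Person(1, "alice") :: Person(2, "bob") :: Nil)
  .toDF
people.registerTempTable("people")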
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 0c98120031242..5abd7b9383366 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.dsl.StringToColumn
import org.apache.spark.sql.test._
/* Implicits */
@@ -28,17 +29,17 @@ class UDFSuite extends QueryTest {
test("Simple UDF") {
udf.register("strLenScala", (_: String).length)
- assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4)
+ assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
udf.register("random0", () => { Math.random()})
- assert(sql("SELECT random0()").first().getDouble(0) >= 0.0)
+ assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
udf.register("strLenScala", (_: String).length + (_:Int))
- assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
+ assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("struct UDF") {
@@ -46,7 +47,7 @@ class UDFSuite extends QueryTest {
val result=
sql("SELECT returnStruct('test', 'test2') as ret")
- .select("ret.f1".attr).first().getString(0)
- assert(result == "test")
+ .select($"ret.f1").head().getString(0)
+ assert(result === "test")
}
}
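The same register-then-query pattern as above, with a made-up UDF name; note that head() now replaces first() on query results. Assumes the suite's imports (TestSQLContext implicits).

udf.register("plusOne", (x: Int) => x + 1)
assert(sql("SELECT plusOne(1)").head().getInt(0) == 2)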
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index fbc8704f7837b..62b2e89403791 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types._
+
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
@@ -66,14 +68,14 @@ class UserDefinedTypeSuite extends QueryTest {
test("register user type: MyDenseVector for MyLabeledPoint") {
- val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
+ val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
val features: RDD[MyDenseVector] =
- pointsRDD.select('features).map { case Row(v: MyDenseVector) => v }
+ pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v }
val featuresArrays: Array[MyDenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
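
The .rdd calls above reflect that a DataFrame is no longer itself an RDD, so dropping back to a typed RDD is explicit (pointsRDD is the suite's test data):

    val labels: RDD[Double] =
      pointsRDD.select('label).rdd.map { case Row(v: Double) => v }  // .rdd before .map
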
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index e61f3c39631da..6f051dfe3d21d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.columnar
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 67007b8c093ca..be5e63c76f42e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.scalatest.FunSuite
import org.apache.spark.sql.{SQLConf, execution}
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
+
class PlannerSuite extends FunSuite {
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
@@ -40,7 +42,7 @@ class PlannerSuite extends FunSuite {
}
test("count is partially aggregated") {
- val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
+ val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
val planned = HashAggregation(query).head
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
@@ -48,14 +50,14 @@ class PlannerSuite extends FunSuite {
}
test("count distinct is partially aggregated") {
- val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
+ val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}
test("mixed aggregates are partially aggregated") {
val query =
- testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
+ testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}
@@ -128,9 +130,9 @@ class PlannerSuite extends FunSuite {
testData.limit(3).registerTempTable("tiny")
sql("CACHE TABLE tiny")
- val a = testData.as('a)
- val b = table("tiny").as('b)
- val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+ val a = testData.as("a")
+ val b = table("tiny").as("b")
+ val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
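
A condensed sketch of the two DSL changes exercised above, with testData and table("tiny") taken from the surrounding suite:

    // Aggregates move from groupBy('col)(expr*) to groupBy('col).agg(expr*)
    val agg = testData.groupBy('value).agg(count('key), countDistinct('key))
    // Aliases take strings and joins take Column predicates
    val joined = testData.as("a").join(table("tiny").as("b"), $"a.key" === $"b.key")
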
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
deleted file mode 100644
index 272c0d4cb2335..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-
-/* Implicit conversions */
-import org.apache.spark.sql.test.TestSQLContext._
-
-/**
- * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns
- * from the input data. These will be replaced during analysis with specific AttributeReferences
- * and then bound to specific ordinals during query planning. While TGFs could also access specific
- * columns using hand-coded ordinals, doing so violates data independence.
- *
- * Note: this is only a rough example of how TGFs can be expressed, the final version will likely
- * involve a lot more sugar for cleaner use in Scala/Java/etc.
- */
-case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator {
- def children = input
- protected def makeOutput() = 'nameAndAge.string :: Nil
-
- val Seq(nameAttr, ageAttr) = input
-
- override def eval(input: Row): TraversableOnce[Row] = {
- val name = nameAttr.eval(input)
- val age = ageAttr.eval(input).asInstanceOf[Int]
-
- Iterator(
- new GenericRow(Array[Any](s"$name is $age years old")),
- new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old")))
- }
-}
-
-class TgfSuite extends QueryTest {
- val inputData =
- logical.LocalRelation('name.string, 'age.int).loadData(
- ("michael", 29) :: Nil
- )
-
- test("simple tgf example") {
- checkAnswer(
- inputData.generate(ExampleTGF()),
- Seq(
- Row("michael is 29 years old"),
- Row("Next year, michael will be 30 years old")))
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 94d14acccbb18..ef198f846c53a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.sql.{Literal, QueryTest, Row, SQLConf}
class JsonSuite extends QueryTest {
import org.apache.spark.sql.json.TestJsonData._
@@ -463,8 +464,8 @@ class JsonSuite extends QueryTest {
// in the Project.
checkAnswer(
jsonSchemaRDD.
- where('num_str > BigDecimal("92233720368547758060")).
- select('num_str + 1.2 as Symbol("num")),
+ where('num_str > Literal(BigDecimal("92233720368547758060"))).
+ select(('num_str + Literal(1.2)).as("num")),
Row(new java.math.BigDecimal("92233720368547758061.2"))
)
@@ -820,7 +821,7 @@ class JsonSuite extends QueryTest {
val schemaRDD1 = applySchema(rowRDD1, schema1)
schemaRDD1.registerTempTable("applySchema1")
- val schemaRDD2 = schemaRDD1.toSchemaRDD
+ val schemaRDD2 = schemaRDD1.toDF
val result = schemaRDD2.toJSON.collect()
assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
@@ -841,7 +842,7 @@ class JsonSuite extends QueryTest {
val schemaRDD3 = applySchema(rowRDD2, schema2)
schemaRDD3.registerTempTable("applySchema2")
- val schemaRDD4 = schemaRDD3.toSchemaRDD
+ val schemaRDD4 = schemaRDD3.toDF
val result2 = schemaRDD4.toJSON.collect()
assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
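
A sketch of the explicit Literal wrapping now used when mixing plain Scala values into DSL expressions (df and the column name are illustrative):

    df.where('amount > Literal(BigDecimal("100")))
      .select(('amount + Literal(1.2)).as("adjusted"))
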
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index 1e7d3e06fc196..c9bc55900de98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -23,7 +23,7 @@ import parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row}
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -41,15 +41,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
private def checkFilterPredicate(
- rdd: SchemaRDD,
+ rdd: DataFrame,
predicate: Predicate,
filterClass: Class[_ <: FilterPredicate],
- checker: (SchemaRDD, Seq[Row]) => Unit,
+ checker: (DataFrame, Seq[Row]) => Unit,
expected: Seq[Row]): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
- val query = rdd.select(output: _*).where(predicate)
+ val query = rdd
+ .select(output.map(e => new org.apache.spark.sql.Column(e)): _*)
+ .where(new org.apache.spark.sql.Column(predicate))
val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect {
case plan: ParquetTableScan => plan.columnPruningPred
@@ -71,13 +73,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
private def checkFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
- (implicit rdd: SchemaRDD): Unit = {
+ (implicit rdd: DataFrame): Unit = {
checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected)
}
private def checkFilterPredicate[T]
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T)
- (implicit rdd: SchemaRDD): Unit = {
+ (implicit rdd: DataFrame): Unit = {
checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
}
@@ -93,24 +95,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - integer") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -118,24 +120,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - long") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -143,24 +145,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - float") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -168,24 +170,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - double") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -197,30 +199,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkFilterPredicate(
'_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString)))
- checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1")
+ checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1")
checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString)))
- checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1")
- checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4")
+ checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1")
+ checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4")
checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1")
checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4")
- checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1")
- checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1")
- checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4")
- checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
- checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
+ checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1")
+ checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1")
+ checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4")
+ checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
+ checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
- checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
+ checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3")
- checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))
+ checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))
}
}
def checkBinaryFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
- (implicit rdd: SchemaRDD): Unit = {
- def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = {
+ (implicit rdd: DataFrame): Unit = {
+ def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = {
assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) {
rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
}
@@ -231,7 +233,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
def checkBinaryFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte])
- (implicit rdd: SchemaRDD): Unit = {
+ (implicit rdd: DataFrame): Unit = {
checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
}
@@ -249,16 +251,16 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkBinaryFilterPredicate(
'_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq)
- checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b)
- checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b)
+ checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b)
+ checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b)
checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b)
checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b)
- checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b)
- checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b)
- checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b)
- checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
- checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
+ checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b)
+ checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b)
checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b)
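
The reworked helper bridges raw Catalyst expressions and the public DataFrame API by wrapping them in Column; in outline:

    val cols  = output.map(e => new org.apache.spark.sql.Column(e))       // Attribute -> Column
    val query = rdd.select(cols: _*).where(new org.apache.spark.sql.Column(predicate))
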
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index a57e4e85a35ef..f03b3a32e34e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -32,12 +32,13 @@ import parquet.schema.{MessageType, MessageTypeParser}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.DecimalType
-import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD}
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
// with an empty configuration (it is after all not intended to be used in this way?)
@@ -97,11 +98,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
}
test("fixed-length decimals") {
- def makeDecimalRDD(decimal: DecimalType): SchemaRDD =
+ def makeDecimalRDD(decimal: DecimalType): DataFrame =
sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
- .select('_1 cast decimal)
+ .select($"_1" cast decimal as "abcd")
for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
withTempPath { dir =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 7900b3e8948d9..a33cf1172cac9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.sources
+import scala.language.existentials
+
import org.apache.spark.sql._
import org.apache.spark.sql.types._
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index 7385952861ee5..bb19ac232fcbe 100755
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -23,6 +23,7 @@ import java.io._
import java.util.{ArrayList => JArrayList}
import jline.{ConsoleReader, History}
+
import org.apache.commons.lang.StringUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.conf.Configuration
@@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket
import org.apache.spark.Logging
import org.apache.spark.sql.hive.HiveShim
-import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim
private[hive] object SparkSQLCLIDriver {
private var prompt = "spark-sql"
diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
index 166c56b9dfe20..ea9d61d8d0f5e 100644
--- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
+++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
@@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.Logging
-import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow}
+import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow}
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
@@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation(
sessionToActivePool: SMap[SessionHandle, String])
extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging {
- private var result: SchemaRDD = _
+ private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _
@@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.toLocalIterator
+ result.rdd.toLocalIterator
} else {
result.collect().iterator
}
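
The incremental-collect branch above, in isolation (configuration key as in the code; result is the DataFrame holding the statement's output):

    val iter =
      if (hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean) {
        result.rdd.toLocalIterator   // DataFrame no longer exposes toLocalIterator directly
      } else {
        result.collect().iterator    // materialize the full result on the driver
      }
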
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
index eaf7a1ddd4996..71e3954b2c7ac 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
@@ -30,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.Logging
-import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
@@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation(
// NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution
extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging {
- private var result: SchemaRDD = _
+ private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _
@@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.toLocalIterator
+ result.rdd.toLocalIterator
} else {
result.collect().iterator
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 274f83af5ac03..b746942cb1067 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.processors._
+import org.apache.hadoop.hive.ql.parse.VariableSubstitution
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
@@ -63,14 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true"
override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
- new this.QueryExecution { val logical = plan }
+ new this.QueryExecution(plan)
- override def sql(sqlText: String): SchemaRDD = {
+ override def sql(sqlText: String): DataFrame = {
+ val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
// TODO: Create a framework for registering parsers instead of just hardcoding if statements.
if (conf.dialect == "sql") {
- super.sql(sqlText)
+ super.sql(substituted)
} else if (conf.dialect == "hiveql") {
- new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText)))
+ new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}
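
With the VariableSubstitution step above, Hive-style variables are expanded before parsing; a minimal usage sketch (mirrors the test added to SQLQuerySuite further down):

    sql("set tbl=src")
    sql("SELECT key FROM ${hiveconf:tbl} LIMIT 1")   // ${hiveconf:tbl} expands to "src" before parsing
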
@@ -350,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
override protected[sql] val planner = hivePlanner
/** Extends QueryExecution with hive specific features. */
- protected[sql] abstract class QueryExecution extends super.QueryExecution {
+ protected[sql] class QueryExecution(logicalPlan: LogicalPlan)
+ extends super.QueryExecution(logicalPlan) {
/**
* Returns the result as a hive compatible sequence of strings. For native commands, the
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 6952b126cf894..ace9329cd5821 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy}
+import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
@@ -55,16 +55,15 @@ private[hive] trait HiveStrategies {
*/
@Experimental
object ParquetConversion extends Strategy {
- implicit class LogicalPlanHacks(s: SchemaRDD) {
- def lowerCase =
- new SchemaRDD(s.sqlContext, s.logicalPlan)
+ implicit class LogicalPlanHacks(s: DataFrame) {
+ def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan)
def addPartitioningAttributes(attrs: Seq[Attribute]) = {
// Don't add the partitioning key if it's already present in the data.
if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) {
s
} else {
- new SchemaRDD(
+ new DataFrame(
s.sqlContext,
s.logicalPlan transform {
case p: ParquetRelation => p.copy(partitioningAttributes = attrs)
@@ -97,13 +96,13 @@ private[hive] trait HiveStrategies {
// We are going to throw the predicates and projection back at the whole optimization
// sequence so let's unresolve all the attributes, allowing them to be rebound to the
// matching parquet attributes.
- val unresolvedOtherPredicates = otherPredicates.map(_ transform {
+ val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
- }).reduceOption(And).getOrElse(Literal(true))
+ }).reduceOption(And).getOrElse(Literal(true)))
- val unresolvedProjection = projectList.map(_ transform {
+ val unresolvedProjection: Seq[Column] = projectList.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
- })
+ }).map(new Column(_))
try {
if (relation.hiveQlTable.isPartitioned) {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 47431cef03e13..8e70ae8f56196 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -99,7 +99,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql))
override def executePlan(plan: LogicalPlan): this.QueryExecution =
- new this.QueryExecution { val logical = plan }
+ new this.QueryExecution(plan)
/** Fewer partitions to speed up testing. */
protected[sql] override lazy val conf: SQLConf = new SQLConf {
@@ -150,8 +150,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
val describedTable = "DESCRIBE (\\w+)".r
- protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
- lazy val logical = HiveQl.parseSql(hql)
+ protected[hive] class HiveQLQueryExecution(hql: String)
+ extends this.QueryExecution(HiveQl.parseSql(hql)) {
def hiveExec() = runSqlHive(hql)
override def toString = hql + "\n" + super.toString
}
@@ -159,7 +159,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
/**
* Override QueryExecution with special debug workflow.
*/
- abstract class QueryExecution extends super.QueryExecution {
+ class QueryExecution(logicalPlan: LogicalPlan)
+ extends super.QueryExecution(logicalPlan) {
override lazy val analyzed = {
val describedTables = logical match {
case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f320d732fb77a..ba391293884bd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -36,12 +36,12 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer either contains all of the keywords, or
* none of the keywords are listed in the answer.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param exists true to make sure the keywords are listed in the output, otherwise
* to make sure none of the keywords are listed in the output
* @param keywords keyword in string array
*/
- def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) {
+ def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
val outputs = rdd.collect().map(_.mkString).mkString
for (key <- keywords) {
if (exists) {
@@ -54,10 +54,10 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -101,7 +101,7 @@ class QueryTest extends PlanTest {
}
}
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(rdd, Seq(expectedAnswer))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index f95a6b43af357..61e5117feab10 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{QueryTest, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.storage.RDDBlockId
class CachedTableSuite extends QueryTest {
@@ -28,7 +28,7 @@ class CachedTableSuite extends QueryTest {
* Throws a test failed exception when the number of cached tables differs from the expected
* number.
*/
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
val planWithCaching = query.queryExecution.withCachedData
val cachedData = planWithCaching collect {
case cached: InMemoryRelation => cached
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 0e6636d38ed3c..5775d83fcbf67 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq
+ testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq
)
// Now overwrite.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index df72be7746ac6..d67b00bc9d08f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -27,11 +27,12 @@ import scala.util.Try
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.{SparkFiles, SparkException}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{SQLConf, Row, SchemaRDD}
case class TestData(a: Int, b: String)
@@ -473,7 +474,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}
- def isExplanation(result: SchemaRDD) = {
+ def isExplanation(result: DataFrame) = {
val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
explanation.contains("== Physical Plan ==")
}
@@ -842,7 +843,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
val KV = "([^=]+)=([^=]*)".r
- def collectResults(rdd: SchemaRDD): Set[(String, String)] =
+ def collectResults(rdd: DataFrame): Set[(String, String)] =
rdd.collect().map {
case Row(key: String, value: String) => key -> value
case Row(KV(key, value)) => key -> value
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 16f77a438e1ae..a081227b4e6b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.hive.execution
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.Row
import org.apache.spark.util.Utils
@@ -82,10 +83,10 @@ class HiveTableScanSuite extends HiveComparisonTest {
sql("create table spark_4959 (col1 string)")
sql("""insert into table spark_4959 select "hi" from src limit 1""")
table("spark_4959").select(
- 'col1.as('CaseSensitiveColName),
- 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2")
+ 'col1.as("CaseSensitiveColName"),
+ 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2")
- assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi"))
- assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi"))
+ assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
+ assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index f2374a215291b..dd0df1a9f6320 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -58,7 +58,7 @@ class HiveUdfSuite extends QueryTest {
| getStruct(1).f3,
| getStruct(1).f4,
| getStruct(1).f5 FROM src LIMIT 1
- """.stripMargin).first() === Row(1, 2, 3, 4, 5))
+ """.stripMargin).head() === Row(1, 2, 3, 4, 5))
}
test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index f6bf2dbb5d6e4..7f9f1ac7cd80d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -104,6 +104,24 @@ class SQLQuerySuite extends QueryTest {
)
}
+ test("command substitution") {
+ sql("set tbl=src")
+ checkAnswer(
+ sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"),
+ sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq)
+
+ sql("set hive.variable.substitute=false") // disable the substitution
+ sql("set tbl2=src")
+ intercept[Exception] {
+ sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect()
+ }
+
+ sql("set hive.variable.substitute=true") // enable the substitution
+ checkAnswer(
+ sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"),
+ sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq)
+ }
+
test("ordering not in select") {
checkAnswer(
sql("SELECT key FROM src ORDER BY value"),
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index e59c24adb84af..0e285d6088ec1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
}
}
+ /**
+ * Get the maximum remember duration across all the input streams. This is a conservative but
+ * safe remember duration which can be used to perform cleanup operations.
+ */
+ def getMaxInputStreamRememberDuration(): Duration = {
+ inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds }
+ }
+
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
logDebug("DStreamGraph.writeObject used")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index e0542eda1383f..c382a12f4d099 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -211,7 +211,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
+ * @deprecated This API is not Java-compatible; use the Java-compatible overload of reduceByWindow.
*/
+ @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0")
def reduceByWindow(
reduceFunc: (T, T) => T,
windowDuration: Duration,
@@ -220,6 +222,24 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration)
}
+ /**
+ * Return a new DStream in which each RDD has a single element generated by reducing all
+ * elements in a sliding window over this DStream.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def reduceByWindow(
+ reduceFunc: JFunction2[T, T, T],
+ windowDuration: Duration,
+ slideDuration: Duration
+ ): JavaDStream[T] = {
+ dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration)
+ }
+
/**
* Return a new DStream in which each RDD has a single element generated by reducing all
* elements in a sliding window over this DStream. However, the reduction is done incrementally
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index afd3c4bc4c4fe..8be04314c4285 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
}
Some(blockRDD)
}
-
- /**
- * Clear metadata that are older than `rememberDuration` of this DStream.
- * This is an internal method that should not be called directly. This
- * implementation overrides the default implementation to clear received
- * block information.
- */
- private[streaming] override def clearMetadata(time: Time) {
- super.clearMetadata(time)
- ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration)
- }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
index ab9fa192191aa..7bf3c33319491 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
@@ -17,7 +17,10 @@
package org.apache.spark.streaming.receiver
-/** Messages sent to the NetworkReceiver. */
+import org.apache.spark.streaming.Time
+
+/** Messages sent to the Receiver. */
private[streaming] sealed trait ReceiverMessage extends Serializable
private[streaming] object StopReceiver extends ReceiverMessage
+private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index d7229c2b96d0b..716cf2c7f32fc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.Time
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl(
case StopReceiver =>
logInfo("Received stop signal")
stop("Stopped by driver", None)
+ case CleanupOldBlocks(threshTime) =>
+ logDebug("Received delete old batch signal")
+ cleanupOldBlocks(threshTime)
}
def ref = self
@@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl(
/** Generate new block ID */
private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement)
+
+ private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = {
+ logDebug(s"Cleaning up blocks older then $cleanupThreshTime")
+ receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds)
+ }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 39b66e1130768..8632c94349bf9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -17,12 +17,14 @@
package org.apache.spark.streaming.scheduler
-import akka.actor.{ActorRef, ActorSystem, Props, Actor}
-import org.apache.spark.{SparkException, SparkEnv, Logging}
-import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
-import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
import scala.util.{Failure, Success, Try}
+import akka.actor.{ActorRef, Props, Actor}
+
+import org.apache.spark.{SparkEnv, Logging}
+import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time}
+import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer}
+
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent
@@ -206,9 +208,13 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " +
timesToReschedule.mkString(", "))
- timesToReschedule.foreach(time =>
+ timesToReschedule.foreach { time =>
+ // Allocate the related blocks when recovering from failure. Blocks that were added but not
+ // allocated before the failure are left dangling in the queue after recovery, so we allocate
+ // them to the next batch, which is the batch they were originally supposed to go into.
+ jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
jobScheduler.submitJobSet(JobSet(time, graph.generateJobs(time)))
- )
+ }
// Restart the timer
timer.start(restartTime.milliseconds)
@@ -238,13 +244,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration)
// If checkpointing is enabled, then checkpoint,
// else mark batch to be fully processed
if (shouldCheckpoint) {
eventActor ! DoCheckpoint(time)
} else {
+ // If checkpointing is not enabled, then delete the metadata of received blocks
+ // (the block data is not saved in any case). Otherwise, wait for the
+ // checkpointing of this batch to complete before cleaning up.
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+ jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
markBatchFullyProcessed(time)
}
}
@@ -252,6 +262,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream checkpoint data for the given `time`. */
private def clearCheckpointData(time: Time) {
ssc.graph.clearCheckpointData(time)
+
+ // All the checkpoint information about which batches have been processed, etc., has
+ // been saved to the checkpoint, so it's safe to delete block metadata and data WAL files.
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+ jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
markBatchFullyProcessed(time)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index c3d9d7b6813d3..e19ac939f9ac5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -67,7 +67,7 @@ private[streaming] class ReceivedBlockTracker(
extends Logging {
private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo]
-
+
private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue]
private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks]
private val logManagerOption = createLogManager()
@@ -107,8 +107,14 @@ private[streaming] class ReceivedBlockTracker(
lastAllocatedBatchTime = batchTime
allocatedBlocks
} else {
- throw new SparkException(s"Unexpected allocation of blocks, " +
- s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ")
+ // This situation occurs when:
+ // 1. The WAL ends with a BatchAllocationEvent but no matching BatchCleanupEvent, so a
+ // fully or partially processed batch job needs to be processed again and the batchTime
+ // equals lastAllocatedBatchTime.
+ // 2. Slow checkpointing makes the recovered batch time older than the lastAllocatedBatchTime
+ // recovered from the WAL.
+ // Both situations occur only during recovery.
+ logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery")
}
}
@@ -150,7 +156,6 @@ private[streaming] class ReceivedBlockTracker(
writeToLog(BatchCleanupEvent(timesToCleanup))
timeToAllocatedBlocks --= timesToCleanup
logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion))
- log
}
/** Stop the block tracker. */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 8dbb42a86e3bd..4f998869731ed 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -24,9 +24,8 @@ import scala.language.existentials
import akka.actor._
import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
-import org.apache.spark.SparkContext._
import org.apache.spark.streaming.{StreamingContext, Time}
-import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver}
+import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver}
/**
* Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
}
- /** Clean up metadata older than the given threshold time */
- def cleanupOldMetadata(cleanupThreshTime: Time) {
+ /**
+ * Clean up the data and metadata of blocks and batches that are strictly
+ * older than the threshold time. Note that this does not wait for the cleanup
+ * of the write ahead log to complete (waitForCompletion is false).
+ */
+ def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) {
+ // Clean up old block and batch metadata
receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false)
+
+ // Signal the receivers to delete old block data
+ if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+ logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
+ receiverInfo.values.flatMap { info => Option(info.actor) }
+ .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) }
+ }
}
/** Register a receiver */
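
Receivers are only asked to delete old block data when the write ahead log is enabled; a minimal sketch of the configuration that turns this path on (key taken from the code above):

    import org.apache.spark.SparkConf
    val conf = new SparkConf()
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")  // receivers then honor CleanupOldBlocks
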
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
index 27a28bab83ed5..858ba3c9eb4e5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
@@ -63,7 +63,7 @@ private[streaming] object HdfsUtils {
}
def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = {
- // For local file systems, return the raw loca file system, such calls to flush()
+ // For local file systems, return the raw local file system, so that calls to flush()
// actually flush the stream.
val fs = path.getFileSystem(conf)
fs match {
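
A hedged sketch of the intent behind getFileSystemForPath: on a local path, LocalFileSystem wraps writes in a checksumming layer whose output stream buffers data, so flush() may not push all bytes to disk; unwrapping to the raw file system avoids that wrapper. This assumes Hadoop on the classpath and mirrors the idea rather than reproducing Spark's exact code.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, LocalFileSystem, Path}

object RawLocalFsSketch {
  def fileSystemFor(path: Path, conf: Configuration): FileSystem =
    path.getFileSystem(conf) match {
      case local: LocalFileSystem => local.getRawFileSystem  // bypass checksum buffering
      case other => other
    }

  def main(args: Array[String]): Unit = {
    val fs = fileSystemFor(new Path("file:///tmp/wal-test"), new Configuration())
    println(fs.getClass.getName)  // RawLocalFileSystem on a local path
  }
}
```
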
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index d92e7fe899a09..d4c40745658c2 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -306,7 +306,17 @@ public void testReduce() {
@SuppressWarnings("unchecked")
@Test
- public void testReduceByWindow() {
+ public void testReduceByWindowWithInverse() {
+ testReduceByWindow(true);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testReduceByWindowWithoutInverse() {
+ testReduceByWindow(false);
+ }
+
+ private void testReduceByWindow(boolean withInverse) {
List<List<Integer>> inputData = Arrays.asList(
Arrays.asList(1,2,3),
Arrays.asList(4,5,6),
@@ -319,8 +329,14 @@ public void testReduceByWindow() {
Arrays.asList(24));
JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream<Integer> reducedWindowed = stream.reduceByWindow(new IntegerSum(),
+ JavaDStream<Integer> reducedWindowed = null;
+ if (withInverse) {
+ reducedWindowed = stream.reduceByWindow(new IntegerSum(),
new IntegerDifference(), new Duration(2000), new Duration(1000));
+ } else {
+ reducedWindowed = stream.reduceByWindow(new IntegerSum(),
+ new Duration(2000), new Duration(1000));
+ }
JavaTestUtils.attachTestOutputStream(reducedWindowed);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 4, 4);
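
The Scala DStream API has the same two overloads the refactored Java test now exercises. A short sketch (assumes Spark Streaming on the classpath; `ints` is any DStream[Int]):

```scala
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.dstream.DStream

object ReduceByWindowSketch {
  def sums(ints: DStream[Int]): (DStream[Int], DStream[Int]) = {
    // With an inverse function: the old slide's contribution is subtracted instead
    // of re-summing the whole window (requires checkpointing in a real job).
    val withInverse = ints.reduceByWindow(_ + _, _ - _, Seconds(2), Seconds(1))
    // Without an inverse function: every window is reduced from scratch.
    val withoutInverse = ints.reduceByWindow(_ + _, Seconds(2), Seconds(1))
    (withInverse, withoutInverse)
  }
}
```
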
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
index 40434b1f9b709..6500608bba87c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
@@ -28,21 +28,16 @@ import java.io.File
*/
class FailureSuite extends TestSuiteBase with Logging {
- var directory = "FailureSuite"
+ val directory = Utils.createTempDir().getAbsolutePath
val numBatches = 30
override def batchDuration = Milliseconds(1000)
override def useManualClock = false
- override def beforeFunction() {
- super.beforeFunction()
- Utils.deleteRecursively(new File(directory))
- }
-
override def afterFunction() {
- super.afterFunction()
Utils.deleteRecursively(new File(directory))
+ super.afterFunction()
}
test("multiple failures with map") {
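
A minimal sketch of the temp-directory lifecycle the suite switches to: create a unique directory up front and delete it recursively in teardown, before handing control back to the parent teardown. The names here are hypothetical stand-ins for Utils.createTempDir and Utils.deleteRecursively.

```scala
import java.io.File
import java.nio.file.Files

object TempDirSketch {
  def deleteRecursively(f: File): Unit = {
    Option(f.listFiles()).foreach(_.foreach(deleteRecursively))
    f.delete()
  }

  def main(args: Array[String]): Unit = {
    val directory = Files.createTempDirectory("FailureSuite").toFile  // like Utils.createTempDir()
    try {
      new File(directory, "checkpoint-0").createNewFile()             // the test would write here
    } finally {
      deleteRecursively(directory)                                    // afterFunction() equivalent
      println(s"cleaned up ${directory.getAbsolutePath}: exists=${directory.exists()}")
    }
  }
}
```
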
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index de7e9d624bf6b..fbb7b0bfebafc 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -82,15 +82,15 @@ class ReceivedBlockTrackerSuite
receivedBlockTracker.allocateBlocksToBatch(2)
receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty
- // Verify that batch 2 cannot be allocated again
- intercept[SparkException] {
- receivedBlockTracker.allocateBlocksToBatch(2)
- }
+ // Verify that allocating an older batch again is a no-op and
+ // returns the same blocks as previously allocated.
+ receivedBlockTracker.allocateBlocksToBatch(1)
+ receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos
- // Verify that older batches cannot be allocated again
- intercept[SparkException] {
- receivedBlockTracker.allocateBlocksToBatch(1)
- }
+ blockInfos.map(receivedBlockTracker.addBlock)
+ receivedBlockTracker.allocateBlocksToBatch(2)
+ receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty
+ receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
}
test("block addition, block to batch allocation and cleanup with write ahead log") {
@@ -186,14 +186,14 @@ class ReceivedBlockTrackerSuite
tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned
tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
}
-
+
test("enabling write ahead log but not setting checkpoint dir") {
conf.set("spark.streaming.receiver.writeAheadLog.enable", "true")
intercept[SparkException] {
createTracker(setCheckpointDir = false)
}
}
-
+
test("setting checkpoint dir but not enabling write ahead log") {
// When WAL config is not set, log manager should not be enabled
val tracker1 = createTracker(setCheckpointDir = true)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index e26c0c6859e57..e8c34a9ee40b9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -17,21 +17,26 @@
package org.apache.spark.streaming
+import java.io.File
import java.nio.ByteBuffer
import java.util.concurrent.Semaphore
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkConf
-import org.apache.spark.storage.{StorageLevel, StreamBlockId}
-import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor}
-import org.scalatest.FunSuite
+import com.google.common.io.Files
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.receiver._
+import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._
+
/** Testsuite for testing the network receiver behavior */
-class ReceiverSuite extends FunSuite with Timeouts {
+class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
test("receiver life cycle") {
@@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts {
val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3
val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1
val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",")
- println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes)
assert(
// the first and last block may be incomplete, so we slice them out
recordedBlocks.drop(1).dropRight(1).forall { block =>
@@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts {
)
}
-
/**
- * An implementation of NetworkReceiver that is used for testing a receiver's life cycle.
+ * Test whether write ahead logs are generated by the receiver,
+ * and automatically cleaned up. The cleanup must be aware of the
+ * remember duration of the input streams. E.g., input streams on which window()
+ * has been applied must remember the data for longer, and hence corresponding
+ * WALs should be cleaned later.
*/
- class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
- @volatile var otherThread: Thread = null
- @volatile var receiving = false
- @volatile var onStartCalled = false
- @volatile var onStopCalled = false
-
- def onStart() {
- otherThread = new Thread() {
- override def run() {
- receiving = true
- while(!isStopped()) {
- Thread.sleep(10)
- }
+ test("write ahead log - generating and cleaning") {
+ val sparkConf = new SparkConf()
+ .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers
+ .setAppName(framework)
+ .set("spark.ui.enabled", "true")
+ .set("spark.streaming.receiver.writeAheadLog.enable", "true")
+ .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
+ val batchDuration = Milliseconds(500)
+ val tempDirectory = Files.createTempDir()
+ val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0))
+ val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1))
+ val allLogFiles1 = new mutable.HashSet[String]()
+ val allLogFiles2 = new mutable.HashSet[String]()
+ logInfo("Temp checkpoint directory = " + tempDirectory)
+
+ def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = {
+ (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2))
+ }
+
+ def getCurrentLogFiles(logDirectory: File): Seq[String] = {
+ try {
+ if (logDirectory.exists()) {
+ logDirectory.listFiles().filter { _.getName.startsWith("log") }.map { _.toString }
+ } else {
+ Seq.empty
}
+ } catch {
+ case e: Exception =>
+ Seq.empty
}
- onStartCalled = true
- otherThread.start()
-
}
- def onStop() {
- onStopCalled = true
- otherThread.join()
+ def printLogFiles(message: String, files: Seq[String]) {
+ logInfo(s"$message (${files.size} files):\n" + files.mkString("\n"))
}
- def reset() {
- receiving = false
- onStartCalled = false
- onStopCalled = false
+ withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc =>
+ tempDirectory.deleteOnExit()
+ val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+ val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+ val receiverStream1 = ssc.receiverStream(receiver1)
+ val receiverStream2 = ssc.receiverStream(receiver2)
+ receiverStream1.register()
+ receiverStream2.window(batchDuration * 6).register() // 3 second window
+ ssc.checkpoint(tempDirectory.getAbsolutePath())
+ ssc.start()
+
+ // Run until sufficient WAL files have been generated and
+ // the first WAL file has been deleted
+ eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) {
+ val (logFiles1, logFiles2) = getBothCurrentLogFiles()
+ allLogFiles1 ++= logFiles1
+ allLogFiles2 ++= logFiles2
+ if (allLogFiles1.size > 0) {
+ assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head))
+ }
+ if (allLogFiles2.size > 0) {
+ assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head))
+ }
+ assert(allLogFiles1.size >= 7)
+ assert(allLogFiles2.size >= 7)
+ }
+ ssc.stop(stopSparkContext = true, stopGracefully = true)
+
+ val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted
+ val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted
+ val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles()
+
+ printLogFiles("Receiver 0: all", sortedAllLogFiles1)
+ printLogFiles("Receiver 0: left", leftLogFiles1)
+ printLogFiles("Receiver 1: all", sortedAllLogFiles2)
+ printLogFiles("Receiver 1: left", leftLogFiles2)
+
+ // Verify that the necessary latest log files are not deleted
+ // receiverStream1 needs to retain just the last batch = 1 log file
+ // receiverStream2 needs to retain 3 seconds of data (3-second window) = 3 log files
+ assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains))
+ assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains))
}
}
@@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts {
}
}
+/**
+ * An implementation of Receiver that is used for testing a receiver's life cycle.
+ */
+class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+ @volatile var otherThread: Thread = null
+ @volatile var receiving = false
+ @volatile var onStartCalled = false
+ @volatile var onStopCalled = false
+
+ def onStart() {
+ otherThread = new Thread() {
+ override def run() {
+ receiving = true
+ var count = 0
+ while(!isStopped()) {
+ if (sendData) {
+ store(count)
+ count += 1
+ }
+ Thread.sleep(10)
+ }
+ }
+ }
+ onStartCalled = true
+ otherThread.start()
+ }
+
+ def onStop() {
+ onStopCalled = true
+ otherThread.join()
+ }
+
+ def reset() {
+ receiving = false
+ onStartCalled = false
+ onStopCalled = false
+ }
+}
+
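
The retention assertions at the end of the new test boil down to simple arithmetic; a purely illustrative sketch (not Spark's cleanup code), assuming the 1-second rolling interval configured above:

```scala
object WalRetentionSketch {
  // How many rolled WAL files must survive cleanup for a stream that has to
  // remember `rememberMillis` of data, with one file per `rollingIntervalMillis`.
  def filesToRetain(rememberMillis: Long, rollingIntervalMillis: Long): Long =
    math.max(1L, math.ceil(rememberMillis.toDouble / rollingIntervalMillis).toLong)

  def main(args: Array[String]): Unit = {
    println(filesToRetain(rememberMillis = 500,  rollingIntervalMillis = 1000)) // 1 (plain stream)
    println(filesToRetain(rememberMillis = 3000, rollingIntervalMillis = 1000)) // 3 (windowed stream)
  }
}
```
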
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 4c35b60c57df3..d00f29665a58f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -60,7 +60,6 @@ private[yarn] class YarnAllocator(
import YarnAllocator._
- // These two complementary data structures are locked on allocatedHostToContainersMap.
// Visible for testing.
val allocatedHostToContainersMap =
new HashMap[String, collection.mutable.Set[ContainerId]]
@@ -355,20 +354,18 @@ private[yarn] class YarnAllocator(
}
}
- allocatedHostToContainersMap.synchronized {
- if (allocatedContainerToHostMap.containsKey(containerId)) {
- val host = allocatedContainerToHostMap.get(containerId).get
- val containerSet = allocatedHostToContainersMap.get(host).get
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val host = allocatedContainerToHostMap.get(containerId).get
+ val containerSet = allocatedHostToContainersMap.get(host).get
- containerSet.remove(containerId)
- if (containerSet.isEmpty) {
- allocatedHostToContainersMap.remove(host)
- } else {
- allocatedHostToContainersMap.update(host, containerSet)
- }
-
- allocatedContainerToHostMap.remove(containerId)
+ containerSet.remove(containerId)
+ if (containerSet.isEmpty) {
+ allocatedHostToContainersMap.remove(host)
+ } else {
+ allocatedHostToContainersMap.update(host, containerSet)
}
+
+ allocatedContainerToHostMap.remove(containerId)
}
}
}
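
A condensed sketch of what the unsynchronized bookkeeping above amounts to: two complementary maps kept in sync, with correctness resting on the assumption that a single allocator thread performs the updates. Names are hypothetical.

```scala
import scala.collection.mutable

object ContainerBookkeepingSketch {
  type ContainerId = String

  private val hostToContainers = mutable.HashMap[String, mutable.Set[ContainerId]]()
  private val containerToHost = mutable.HashMap[ContainerId, String]()

  def add(host: String, id: ContainerId): Unit = {
    hostToContainers.getOrElseUpdate(host, mutable.HashSet[ContainerId]()) += id
    containerToHost(id) = host
  }

  def releaseContainer(id: ContainerId): Unit = {
    // Remove the container from both maps, dropping the host entry when empty.
    containerToHost.get(id).foreach { host =>
      val containers = hostToContainers(host)
      containers -= id
      if (containers.isEmpty) hostToContainers -= host
      containerToHost -= id
    }
  }

  def main(args: Array[String]): Unit = {
    add("host-a", "container-1")
    releaseContainer("container-1")
    println(hostToContainers.isEmpty && containerToHost.isEmpty) // true
  }
}
```
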