diff --git a/assembly/pom.xml b/assembly/pom.xml index b2a9d0780ee2b..594fa0c779e1b 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -142,8 +142,10 @@ com/google/common/base/Absent* + com/google/common/base/Function com/google/common/base/Optional* com/google/common/base/Present* + com/google/common/base/Supplier diff --git a/bin/spark-class b/bin/spark-class index 1b945461fabc8..2f0441bb3c1c2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" . "$FWDIR"/bin/load-spark-env.sh @@ -120,8 +121,8 @@ fi JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" +if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then + JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`" fi # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! diff --git a/build/mvn b/build/mvn index 43471f83e904c..f91e2b4bdcc02 100755 --- a/build/mvn +++ b/build/mvn @@ -68,10 +68,10 @@ install_app() { # Install maven under the build/ folder install_mvn() { install_app \ - "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \ - "apache-maven-3.2.3-bin.tar.gz" \ - "apache-maven-3.2.3/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ + "apache-maven-3.2.5-bin.tar.gz" \ + "apache-maven-3.2.5/bin/mvn" + MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" } # Install zinc under the build/ folder diff --git a/core/pom.xml b/core/pom.xml index d9a49c9e08afc..1984682b9c099 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -372,8 +372,10 @@ com.google.guava:guava com/google/common/base/Absent* + com/google/common/base/Function com/google/common/base/Optional* com/google/common/base/Present* + com/google/common/base/Supplier diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a1f7133f897ee..f23ba9dba167f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -190,6 +190,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ -.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time, +.getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index d4f2624061e35..419d093d55643 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -118,15 +118,17 @@ trait Logging { // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently // org.apache.logging.slf4j.Log4jLoggerFactory val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) - val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4j12Initialized && usingLog4j12) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (usingLog4j12) { + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized) { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a0ce107f43b16..cd91c8f87547b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,8 +17,11 @@ package org.apache.spark +import java.util.concurrent.ConcurrentHashMap + import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, LinkedHashSet} +import scala.collection.mutable.LinkedHashSet + import org.apache.spark.serializer.KryoSerializer /** @@ -46,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new HashMap[String, String]() + private val settings = new ConcurrentHashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) { - settings(k) = v + set(k, v) } } @@ -63,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } - settings(key) = value + settings.put(key, value) this } @@ -129,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]) = { - this.settings ++= settings + this.settings.putAll(settings.toMap.asJava) this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - if (!settings.contains(key)) { - settings(key) = value - } + settings.putIfAbsent(key, value) this } @@ -163,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Get a parameter; throws a NoSuchElementException if it's not set */ def get(key: String): String = { - settings.getOrElse(key, throw new NoSuchElementException(key)) + getOption(key).getOrElse(throw new NoSuchElementException(key)) } /** Get a parameter, falling back to a default if not set */ def get(key: String, defaultValue: String): String = { - settings.getOrElse(key, defaultValue) + getOption(key).getOrElse(defaultValue) } /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - settings.get(key) + Option(settings.get(key)) } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.clone().toArray + def getAll: Array[(String, String)] = { + settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray + } /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { @@ -224,11 +227,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getAppId: String = get("spark.app.id") /** Does the configuration contain a given parameter? */ - def contains(key: String): Boolean = settings.contains(key) + def contains(key: String): Boolean = settings.containsKey(key) /** Copy this object */ override def clone: SparkConf = { - new SparkConf(false).setAll(settings) + new SparkConf(false).setAll(getAll) } /** @@ -240,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { - if (settings.contains("spark.local.dir")) { + if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." logWarning(msg) @@ -265,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate spark.executor.extraJavaOptions - settings.get(executorOptsKey).map { javaOpts => + getOption(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." @@ -345,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * configuration out for debugging. */ def toDebugString: String = { - settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") + getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a354ed4d1486..4c4ee04cc515e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() + @volatile private var stopped: Boolean = false + + private def assertNotStopped(): Unit = { + if (stopped) { + throw new IllegalStateException("Cannot call methods on a stopped SparkContext") + } + } + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -525,6 +533,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * modified collection. Pass a copy of the argument to avoid this. */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } @@ -540,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -549,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Hadoop-supported file system URI, and return it as an RDD of Strings. */ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minPartitions).map(pair => pair._2.toString).setName(path) } @@ -582,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -627,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -651,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) : RDD[Array[Byte]] = { + assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, classOf[FixedLengthBinaryInputFormat], @@ -684,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) @@ -703,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -782,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + assertNotStopped() val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -802,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() new NewHadoopRDD(this, fClass, kClass, vClass, conf) } @@ -817,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int ): RDD[(K, V)] = { + assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } @@ -828,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V] - ): RDD[(K, V)] = + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) + } /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -858,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { + assertNotStopped() val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] @@ -879,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions ): RDD[T] = { + assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } @@ -954,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * The variable will be sent to each cluster only once. */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { + assertNotStopped() + if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have created RDD broadcast variables but not used them: + logWarning("Can not directly broadcast RDDs; instead, call collect() and " + + "broadcast the result (see SPARK-5063)") + } val bc = env.broadcastManager.newBroadcast[T](value, isLocal) val callSite = getCallSite logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) @@ -1046,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * memory available for caching. */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + assertNotStopped() env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.host + ":" + blockManagerId.port, mem) } @@ -1058,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + assertNotStopped() val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) @@ -1075,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { + assertNotStopped() env.blockManager.master.getStorageStatus } @@ -1084,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getAllPools: Seq[Schedulable] = { + assertNotStopped() // TODO(xiajunluan): We should take nested pools into account taskScheduler.rootPool.schedulableQueue.toSeq } @@ -1094,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getPoolForName(pool: String): Option[Schedulable] = { + assertNotStopped() Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) } @@ -1101,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return current scheduling mode */ def getSchedulingMode: SchedulingMode.SchedulingMode = { + assertNotStopped() taskScheduler.schedulingMode } @@ -1206,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { postApplicationEnd() ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { + if (!stopped) { + stopped = true env.metricsSystem.report() metadataCleaner.cancel() env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() + dagScheduler.stop() + dagScheduler = null taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -1289,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (dagScheduler == null) { - throw new SparkException("SparkContext has been shutdown") + if (stopped) { + throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite val cleanedFunc = clean(func) @@ -1377,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { + assertNotStopped() val callSite = getCallSite logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime @@ -1399,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit, resultFunc: => R): SimpleFutureAction[R] = { + assertNotStopped() val cleanF = clean(processPartition) val callSite = getCallSite val waiter = dagScheduler.submitJob( @@ -1417,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * for more information. */ def cancelJobGroup(groupId: String) { + assertNotStopped() dagScheduler.cancelJobGroup(groupId) } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs() { + assertNotStopped() dagScheduler.cancelAllJobs() } @@ -1468,13 +1505,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getCheckpointDir = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ - def defaultParallelism: Int = taskScheduler.defaultParallelism + def defaultParallelism: Int = { + assertNotStopped() + taskScheduler.defaultParallelism + } /** Default min number of partitions for Hadoop RDDs when not given by user */ @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** Default min number of partitions for Hadoop RDDs when not given by user */ + /** + * Default min number of partitions for Hadoop RDDs when not given by user + * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. + * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 + */ def defaultMinPartitions: Int = math.min(defaultParallelism, 2) private val nextShuffleId = new AtomicInteger(0) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33f..1264a8126153b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -326,6 +326,10 @@ object SparkEnv extends Logging { // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { + // We need to set the executor ID before the MetricsSystem is created because sources and + // sinks specified in the metrics configuration file will want to incorporate this executor's + // ID into the metrics they report. + conf.set("spark.executor.id", executorId) val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) ms.start() ms diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 57f9faf5ddd1d..211e3ede53d9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -133,10 +133,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() @@ -156,10 +155,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesWritten = f() @@ -172,10 +170,8 @@ class SparkHadoopUtil extends Logging { } } - private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + private def getFileSystemThreadStatistics(): Seq[AnyRef] = { + val stats = FileSystem.getAllStatistics() stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2b084a2d73b78..0ae45f4ad9130 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -203,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!logInfos.isEmpty) { val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() def addIfAbsent(info: FsApplicationHistoryInfo) = { - if (!newApps.contains(info.id)) { + if (!newApps.contains(info.id) || + newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) && + !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { newApps += (info.id -> info) } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9a4adfbbb3d71..823825302658c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend( } case x: DisassociatedEvent => - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) + if (x.remoteAddress == driver.anchorPath.address) { + logError(s"Driver $x disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"Received irrelevant DisassociatedEvent $x") + } case StopExecutor => logInfo("Driver commanded a shutdown") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42566d1a14093..d8c2e41a7c715 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -41,11 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} */ private[spark] class Executor( executorId: String, - slaveHostname: String, + executorHostname: String, env: SparkEnv, isLocal: Boolean = false) extends Logging { + + logInfo(s"Starting executor ID $executorId on host $executorHostname") + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -58,12 +61,12 @@ private[spark] class Executor( @volatile private var isStopped = false // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) + assert (0 == Utils.parseHostPort(executorHostname)._2) // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) + Utils.setCustomHostname(executorHostname) if (!isLocal) { // Setup an uncaught exception handler for non-local mode. diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ddb5903bf6875..97912c68c5982 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.executor import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.executor.DataReadMethod import org.apache.spark.executor.DataReadMethod.DataReadMethod import scala.collection.mutable.ArrayBuffer diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 45633e3de01dd..83e8eb71260eb 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -130,8 +130,8 @@ private[spark] class MetricsSystem private ( if (appId.isDefined && executorId.isDefined) { MetricRegistry.name(appId.get, executorId.get, source.sourceName) } else { - // Only Driver and Executor are set spark.app.id and spark.executor.id. - // For instance, Master and Worker are not related to a specific application. + // Only Driver and Executor set spark.app.id and spark.executor.id. + // Other instance types, e.g. Master and Worker, are not related to a specific application. val warningMsg = s"Using default name $defaultName for source because %s is not set." if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 056aef0bc210a..c3e3931042de2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -218,13 +219,13 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.inputSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null @@ -254,7 +255,8 @@ class HadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + } else if (split.inputSplit.value.isInstanceOf[FileSplit] || + split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b0e3c87ccff4..d86f95ac3e485 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat @@ -34,7 +34,7 @@ import org.apache.spark.Logging import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils @@ -114,13 +114,13 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.serializableHadoopSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) @@ -163,7 +163,8 @@ class NewHadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0f37d830ef34f..49b88a90ab5af 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -990,7 +990,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { @@ -1061,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1086,11 +1086,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.commitJob() } - private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) - : (OutputMetrics, Option[() => Long]) = { - val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) - .map(new Path(_)) - .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) if (bytesWrittenCallback.isDefined) { context.taskMetrics.outputMetrics = Some(outputMetrics) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 97012c7033f9f..ab7410a1f7f99 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * on RDD internals. */ abstract class RDD[T: ClassTag]( - @transient private var sc: SparkContext, + @transient private var _sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { + if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have defined nested RDDs without running jobs with them. + logWarning("Spark does not support nested RDDs (see SPARK-5063)") + } + + private def sc: SparkContext = { + if (_sc == null) { + throw new SparkException( + "RDD transformations and actions can only be invoked by the driver, not inside of other " + + "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " + + "the values transformation and count action cannot be performed inside of the rdd1.map " + + "transformation. For more information, see SPARK-5063.") + } + _sc + } + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 6f446c5a95a0a..4307029d44fbb 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,8 +24,10 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = - """Time spent deserializating the task closure on the executor.""" + val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." + + val SHUFFLE_READ_BLOCKED_TIME = + "Time that the task spent blocked waiting for shuffle data to be read from remote machines." val INPUT = "Bytes read from Hadoop or from Spark storage." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 09a936c2234c0..d8be1b20b3acd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -132,6 +132,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Task Deserialization Time + {if (hasShuffleRead) { +
  • + + + Shuffle Read Blocked Time + +
  • + }}
  • @@ -167,7 +176,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input", "")) else Nil} ++ {if (hasOutput) Seq(("Output", "")) else Nil} ++ - {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ + {if (hasShuffleRead) { + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read", "")) + } else { + Nil + }} ++ {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) else Nil} ++ @@ -271,6 +285,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val outputQuantiles = Output +: getFormattedSizeQuantiles(outputSizes) + val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + } + val shuffleReadBlockedQuantiles = Shuffle Read Blocked Time +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } @@ -308,7 +328,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {gettingResultQuantiles}, if (hasInput) {inputQuantiles} else Nil, if (hasOutput) {outputQuantiles} else Nil, - if (hasShuffleRead) {shuffleReadQuantiles} else Nil, + if (hasShuffleRead) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadQuantiles} + } else { + Nil + }, if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) @@ -377,6 +404,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") + val maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime) + val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("") + val shuffleReadBlockedTimeReadable = + maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -449,6 +481,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { }} {if (hasShuffleRead) { + + {shuffleReadBlockedTimeReadable} + {shuffleReadReadable} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 2d13bb6ddde42..37cf2c207ba40 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -27,6 +27,7 @@ package org.apache.spark.ui.jobs private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" + val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 7584ae79fc920..21487bc24d58a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } - ignore("two jobs sharing the same stage") { + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched - // sem2: make sure the first stage is not finished until cancel is issued + // twoJobsSharingStageSemaphore: + // make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) - val sem2 = new Semaphore(0) sc = new SparkContext("local[2]", "test") sc.addSparkListener(new SparkListener { @@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter // Create two actions that would share the some stages. val rdd = sc.parallelize(1 to 10, 2).map { i => - sem2.acquire() + JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) }.reduceByKey(_+_) val f1 = rdd.collectAsync() @@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter future { sem1.acquire() f1.cancel() - sem2.release(10) + JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) } - // Expect both to fail now. - // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + // Expect f1 to fail due to cancellation, intercept[SparkException] { f1.get() } - intercept[SparkException] { f2.get() } + // but f2 should not be affected + f2.get() } def testCount() { @@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter object JobCancellationSuite { val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) + val twoJobsSharingStageSemaphore = new Semaphore(0) } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index b0a70f012f1f3..af3272692d7a1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { testPackage.runCallSiteTest(sc) } + test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") { + sc = new SparkContext("local", "test") + sc.stop() + val thrown = intercept[IllegalStateException] { + sc.broadcast(Seq(1, 2, 3)) + } + assert(thrown.getMessage.toLowerCase.contains("stopped")) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). @@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - val broadcast = sc.broadcast(rdd) + val broadcast = sc.broadcast(Array(1, 2, 3, 4)) broadcast.destroy() val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 8379883e065e7..3fbc1a21d10ed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -167,6 +167,29 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers list.size should be (1) } + test("history file is renamed from inprogress to completed") { + val conf = new SparkConf() + .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set("spark.testing", "true") + val provider = new FsHistoryProvider(conf) + + val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L) + ) + provider.checkForLogs() + val appListBeforeRename = provider.getListing() + appListBeforeRename.size should be (1) + appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS) + + logFile1.renameTo(new File(testDir, "app1")) + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS) + } + private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec], events: SparkListenerEvent*) = { val out = diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 1a28a9a187cd7..372d7aa453008 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() @@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 10a39990f80ce..81db66ae17464 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -26,7 +26,16 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf, + LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter, + TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, + CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, + TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, + CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, + FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.spark.SharedSparkContext import org.apache.spark.deploy.SparkHadoopUtil @@ -202,7 +211,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { val fs = FileSystem.getLocal(new Configuration()) val outPath = new Path(fs.getWorkingDirectory, "outdir") - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) { + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { val taskBytesWritten = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { @@ -225,4 +234,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { } } } + + test("input metrics with old CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable], + classOf[Text], 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable], + classOf[Text], new Configuration()).count() + } + assert(bytesRead >= tmpFile.length()) + } +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] { + override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter) + : OldRecordReader[LongWritable, Text] = { + new OldCombineFileRecordReader[LongWritable, Text](conf, + split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper] + .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]]) + } +} + +class OldCombineTextRecordReaderWrapper( + split: OldCombineFileSplit, + conf: Configuration, + reporter: Reporter, + idx: Integer) extends OldRecordReader[LongWritable, Text] { + + val fileSplit = new OldFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit, + conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader] + + override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value) + override def createKey(): LongWritable = delegate.createKey() + override def createValue(): Text = delegate.createValue() + override def getPos(): Long = delegate.getPos + override def close(): Unit = delegate.close() + override def getProgress(): Float = delegate.getProgress +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { + def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) + : NewRecordReader[LongWritable, Text] = { + new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + context, classOf[NewCombineTextRecordReaderWrapper]) + } } + +class NewCombineTextRecordReaderWrapper( + split: NewCombineFileSplit, + context: TaskAttemptContext, + idx: Integer) extends NewRecordReader[LongWritable, Text] { + + val fileSplit = new NewFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context) + + override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = { + delegate.initialize(fileSplit, context) + } + + override def nextKeyValue(): Boolean = delegate.nextKeyValue() + override def getCurrentKey(): LongWritable = delegate.getCurrentKey + override def getCurrentValue(): Text = delegate.getCurrentValue + override def getProgress(): Float = delegate.getProgress + override def close(): Unit = delegate.close() +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 381ee2d45630f..e33b4bbbb8e4c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -927,4 +927,44 @@ class RDDSuite extends FunSuite with SharedSparkContext { mutableDependencies += dep } } + + test("nested RDDs are not supported (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator } + nestedRDD.count() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("actions cannot be performed inside of transformations (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + rdd.map(x => x * rdd2.count).collect() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("cannot run actions after SparkContext has been stopped (SPARK-5063)") { + val existingRDD = sc.parallelize(1 to 100) + sc.stop() + val thrown = intercept[IllegalStateException] { + existingRDD.count() + } + assert(thrown.getMessage.contains("shutdown")) + } + + test("cannot call methods on a stopped SparkContext (SPARK-5063)") { + sc.stop() + def assertFails(block: => Any): Unit = { + val thrown = intercept[IllegalStateException] { + block + } + assert(thrown.getMessage.contains("stopped")) + } + assertFails { sc.parallelize(1 to 100) } + assertFails { sc.textFile("/nonexistent-path") } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index dae7bf0e336de..8cf951adb354b 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -49,7 +49,7 @@ class LocalDirsSuite extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } // spark.local.dir only contains invalid directories, but that's not a problem since diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 10541f878476c..1026cb2aa7cae 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -41,7 +41,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() (1 to 100).foreach(eventLoop.post) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert((1 to 100) === buffer.toSeq) } eventLoop.stop() @@ -76,7 +76,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) } eventLoop.stop() @@ -98,7 +98,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) assert(eventLoop.isActive) } @@ -153,7 +153,7 @@ class EventLoopSuite extends FunSuite with Timeouts { }.start() } - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(threadNum * eventsFromEachThread === receivedEventsCount) } eventLoop.stop() @@ -185,4 +185,22 @@ class EventLoopSuite extends FunSuite with Timeouts { } assert(false === eventLoop.isActive) } + + test("EventLoop: stop in eventThread") { + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + } } diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b1b8cb44e098b..b2a7e092a0291 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -122,8 +122,14 @@ if [[ ! "$@" =~ --package-only ]]; then for file in $(find . -type f) do echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - gpg --print-md MD5 $file > $file.md5; - gpg --print-md SHA1 $file > $file.sha1 + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 done nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id diff --git a/docs/configuration.md b/docs/configuration.md index efbab4085317a..7c5b6d011cfd3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -197,6 +197,27 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment + + + + + + + + + + + + + + + diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 2094963392295..ef18cec9371d6 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -192,12 +192,11 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva recommendation by measuring the Mean Squared Error of rating prediction. {% highlight python %} -from pyspark.mllib.recommendation import ALS -from numpy import array +from pyspark.mllib.recommendation import ALS, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda line: array([float(x) for x in line.split(',')])) +ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) # Build the recommendation model using Alternating Least Squares rank = 10 @@ -205,10 +204,10 @@ numIterations = 20 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data -testdata = ratings.map(lambda p: (int(p[0]), int(p[1]))) +testdata = ratings.map(lambda p: (p[0], p[1])) predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count() +MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() print("Mean Squared Error = " + str(MSE)) {% endhighlight %} @@ -217,7 +216,7 @@ signals), you can use the trainImplicit method to get better results. {% highlight python %} # Build the recommendation model using Alternating Least Squares based on implicit ratings -model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) +model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01) {% endhighlight %} diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 0e38fe2144e9f..77c0abbbacbd0 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 3bd1deaccfafe..14a87f8436984 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for standalone -clusters, Mesos clusters, or Python applications. +the drivers and the executors. Note that `cluster` mode is currently not supported for +Mesos clusters or Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 2adc63f7ff30e..387c0e421334b 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -76,7 +76,7 @@ object KafkaWordCountProducer { val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args - // Zookeper connection properties + // Zookeeper connection properties val props = new Properties() props.put("metadata.broker.list", brokers) props.put("serializer.class", "kafka.serializer.StringEncoder") diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 247d2a5e31a8c..0fbee6e433608 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,7 +33,7 @@ import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,11 +112,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + cvModel.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 5b92655e2e838..eaaa344be49c8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -48,13 +48,13 @@ public static void main(String[] args) { // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -94,14 +94,14 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerAsTable("results"); - SchemaRDD results = + model2.transform(test).registerTempTable("results"); + DataFrame results = jsql.sql("SELECT features, label, probability, prediction FROM results"); for (Row r: results.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 74db449fada7d..82d665a3e1386 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -29,7 +29,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,11 +79,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - model.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + model.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index b70804635d5c9..8defb769ffaaf 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaSparkSQL { public static class Person implements Serializable { @@ -74,13 +74,13 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.toJavaRDD().map(new Function() { @Override @@ -93,17 +93,17 @@ public String call(Row row) { } System.out.println("=== Data source: Parquet File ==="); - // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. + // DataFrames can be saved as parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. - // The result of loading a parquet file is also a JavaSchemaRDD. - SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + // The result of loading a parquet file is also a DataFrame. + DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - SchemaRDD teenagers2 = + DataFrame teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override @@ -119,8 +119,8 @@ public String call(Row row) { // A JSON dataset is pointed by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; - // Create a JavaSchemaRDD from the file(s) pointed by path - SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + // Create a DataFrame from the file(s) pointed by path + DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -130,13 +130,13 @@ public String call(Row row) { // |-- age: IntegerType // |-- name: StringType - // Register this JavaSchemaRDD as a table. + // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. - SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagerNames = teenagers3.toJavaRDD().map(new Function() { @Override @@ -146,14 +146,14 @@ public String call(Row row) { System.out.println(name); } - // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // Alternatively, a DataFrame can be created for a JSON dataset represented by // a RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); - // Take a look at the schema of this new JavaSchemaRDD. + // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); // The schema of anotherPeople is ... // root @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index 540dae785f6ea..b5a70db2b9a3c 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -16,7 +16,7 @@ # """ -An example of how to use SchemaRDD as a dataset for ML. Run with:: +An example of how to use DataFrame as a dataset for ML. Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d2c5ca48c6cb8..7f5c68e3d0fe2 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -30,18 +30,18 @@ some_rdd = sc.parallelize([Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]) - # Infer schema from the first row, create a SchemaRDD and print the schema - some_schemardd = sqlContext.inferSchema(some_rdd) - some_schemardd.printSchema() + # Infer schema from the first row, create a DataFrame and print the schema + some_df = sqlContext.inferSchema(some_rdd) + some_df.printSchema() # Another RDD is created from a list of tuples another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)]) # Schema with two fields - person_name and person_age schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) - # Create a SchemaRDD by applying the schema to the RDD and print the schema - another_schemardd = sqlContext.applySchema(another_rdd, schema) - another_schemardd.printSchema() + # Create a DataFrame by applying the schema to the RDD and print the schema + another_df = sqlContext.applySchema(another_rdd, schema) + another_df.printSchema() # root # |-- age: integer (nullable = true) # |-- name: string (nullable = true) @@ -49,7 +49,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") - # Create a SchemaRDD from the file(s) pointed to by path + # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root # |-- person_name: string (nullable = false) @@ -61,7 +61,7 @@ # |-- age: IntegerType # |-- name: StringType - # Register this SchemaRDD as a table. + # Register this DataFrame as a table. people.registerAsTable("people") # SQL statements can be run by using the sql methods provided by sqlContext diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index d8c7ef38ee46d..283bb80f1c788 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator @@ -101,7 +100,7 @@ object CrossValidatorExample { // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala new file mode 100644 index 0000000000000..b7885829459a3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.recommendation.ALS +import org.apache.spark.sql.{Row, SQLContext} + +/** + * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). + * Run with + * {{{ + * bin/run-example ml.MovieLensALS + * }}} + */ +object MovieLensALS { + + case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) + + object Rating { + def parseRating(str: String): Rating = { + val fields = str.split("::") + assert(fields.size == 4) + Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) + } + } + + case class Movie(movieId: Int, title: String, genres: Seq[String]) + + object Movie { + def parseMovie(str: String): Movie = { + val fields = str.split("::") + assert(fields.size == 3) + Movie(fields(0).toInt, fields(1), fields(2).split("|")) + } + } + + case class Params( + ratings: String = null, + movies: String = null, + maxIter: Int = 10, + regParam: Double = 0.1, + rank: Int = 10, + numBlocks: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("MovieLensALS") { + head("MovieLensALS: an example app for ALS on MovieLens data.") + opt[String]("ratings") + .required() + .text("path to a MovieLens dataset of ratings") + .action((x, c) => c.copy(ratings = x)) + opt[String]("movies") + .required() + .text("path to a MovieLens dataset of movies") + .action((x, c) => c.copy(movies = x)) + opt[Int]("rank") + .text(s"rank, default: ${defaultParams.rank}}") + .action((x, c) => c.copy(rank = x)) + opt[Int]("maxIter") + .text(s"max number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Int]("numBlocks") + .text(s"number of blocks, default: ${defaultParams.numBlocks}") + .action((x, c) => c.copy(numBlocks = x)) + note( + """ + |Example command line to run this app: + | + | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ + | examples/target/scala-*/spark-examples-*.jar \ + | --rank 10 --maxIter 15 --regParam 0.1 \ + | --movies path/to/movielens/movies.dat \ + | --ratings path/to/movielens/ratings.dat + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"MovieLensALS with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() + + val numRatings = ratings.count() + val numUsers = ratings.map(_.userId).distinct().count() + val numMovies = ratings.map(_.movieId).distinct().count() + + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + val splits = ratings.randomSplit(Array(0.8, 0.2), 0L) + val training = splits(0).cache() + val test = splits(1).cache() + + val numTraining = training.count() + val numTest = test.count() + println(s"Training: $numTraining, test: $numTest.") + + ratings.unpersist(blocking = false) + + val als = new ALS() + .setUserCol("userId") + .setItemCol("movieId") + .setRank(params.rank) + .setMaxIter(params.maxIter) + .setRegParam(params.regParam) + .setNumBlocks(params.numBlocks) + + val model = als.fit(training) + + val predictions = model.transform(test).cache() + + // Evaluate the model. + // TODO: Create an evaluator to compute RMSE. + val mse = predictions.select("rating", "prediction").rdd + .flatMap { case Row(rating: Float, prediction: Float) => + val err = rating.toDouble - prediction + val err2 = err * err + if (err2.isNaN) { + None + } else { + Some(err2) + } + }.mean() + val rmse = math.sqrt(mse) + println(s"Test RMSE = $rmse.") + + // Inspect false positives. + predictions.registerTempTable("prediction") + sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") + sqlContext.sql( + """ + |SELECT userId, prediction.movieId, title, rating, prediction + | FROM prediction JOIN movie ON prediction.movieId = movie.movieId + | WHERE rating <= 1 AND prediction >= 4 + | LIMIT 100 + """.stripMargin) + .collect() + .foreach(println) + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a2adff929cb..95cc9801eaeb9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -42,7 +41,7 @@ object SimpleParamsExample { // Prepare training data. // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. val training = sparkContext.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), @@ -92,7 +91,7 @@ object SimpleParamsExample { // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. model2.transform(test) - .select('features, 'label, 'probability, 'prediction) + .select("features", "label", "probability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index b9a6ef0229def..065db62b0f5ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,7 +20,6 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} @@ -80,7 +79,7 @@ object SimpleTextClassificationPipeline { // Make predictions on test documents. model.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index f8d83f4ec7327..f229a58985a3e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{Row, SQLContext, DataFrame} /** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} @@ -47,7 +47,7 @@ object DatasetExample { val defaultParams = Params() val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + head("Dataset: an example app using DataFrame as a Dataset for ML.") opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) @@ -80,20 +80,20 @@ object DatasetExample { } println(s"Loaded ${origData.count()} instances from file: ${params.input}") - // Convert input data to SchemaRDD explicitly. - val schemaRDD: SchemaRDD = origData - println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") - println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + // Convert input data to DataFrame explicitly. + val df: DataFrame = origData.toDF + println(s"Inferred schema:\n${df.schema.prettyJson}") + println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to SchemaRDD. - val labelsSchemaRDD: SchemaRDD = origData.select('label) - val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + // Select columns, using implicit conversion to DataFrames. + val labelsDf: DataFrame = origData.select("label") + val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresSchemaRDD: SchemaRDD = origData.select('features) - val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featuresDf: DataFrame = origData.select("features") + val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) @@ -103,13 +103,13 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - schemaRDD.saveAsParquetFile(outputDir) + df.saveAsParquetFile(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2e98b2dc30b80..a5d7f262581f5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,6 +19,8 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.dsl._ +import org.apache.spark.sql.dsl.literals._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -54,7 +56,7 @@ object RDDRelation { rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. rdd.saveAsParquetFile("pair.parquet") @@ -63,7 +65,7 @@ object RDDRelation { val parquetFile = sqlContext.parquetFile("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. - parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. parquetFile.registerTempTable("parquetFile") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 897c7ee12a436..f1550ac2e18ad 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( * partitioner that allows co-partitioning with `partitionsRDD`. */ override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 9da0064104fb6..ed9876b8dc21c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("non-default number of edge partitions") { + val n = 10 + val defaultParallelism = 3 + val numEdgePartitions = 4 + assert(defaultParallelism != numEdgePartitions) + val conf = new org.apache.spark.SparkConf() + .set("spark.default.parallelism", defaultParallelism.toString) + val sc = new SparkContext("local", "test", conf) + try { + val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), + numEdgePartitions) + val graph = Graph.fromEdgeTuples(edges, 1) + val neighborAttrSums = graph.mapReduceTriplets[Int]( + et => Iterator((et.dstId, et.srcAttr)), _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + } finally { + sc.stop() + } + } + } diff --git a/make-distribution.sh b/make-distribution.sh index 4e2f400be3053..0adca7851819b 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -115,7 +115,7 @@ if which git &>/dev/null; then unset GITREV fi -if ! which $MVN &>/dev/null; then +if ! which "$MVN" &>/dev/null; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; @@ -171,13 +171,16 @@ cd "$SPARK_HOME" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="$MVN clean package -DskipTests $@" +# Store the command as an array because $MVN variable might have spaces in it. +# Normal quoting tricks don't work. +# See: http://mywiki.wooledge.org/BashFAQ/050 +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." -echo -e "\$ $BUILD_COMMAND\n" +echo -e "\$ ${BUILD_COMMAND[@]}\n" -${BUILD_COMMAND} +"${BUILD_COMMAND[@]}" # Make directories rm -rf "$DISTDIR" diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 77d230eb4a122..bc3defe968afd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -21,7 +21,7 @@ import scala.annotation.varargs import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -38,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @return fitted model */ @varargs - def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { val map = new ParamMap().put(paramPairs: _*) fit(dataset, map) } @@ -50,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMap parameter map * @return fitted model */ - def fit(dataset: SchemaRDD, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -61,7 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMaps an array of parameter maps * @return fitted models, matching the input parameter maps */ - def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index db563dd550e56..d2ca2e6871e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index ad6fed178fae9..fe39cd1bc0bd2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] { * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap val theStages = map(stages) @@ -162,7 +162,7 @@ class PipelineModel private[ml] ( } } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap transformSchema(dataset.schema, map, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index af56f9c435351..b233bff08305c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,9 +22,9 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.types._ /** @@ -41,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ @varargs - def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() paramPairs.foreach(map.put(_)) transform(dataset, map) @@ -53,7 +53,7 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame } /** @@ -95,11 +95,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O StructType(outputFields) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) - dataset.select(Star(None), udf as map(outputCol)) + dataset.select($"*", callUDF( + this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f8316..eeb6301c3f64a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti def setScoreCol(value: String): this.type = set(scoreCol, value) def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + val instances = dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }.persist(StorageLevel.MEMORY_AND_DISK) @@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] ( validateAndTransformSchema(schema, paramMap, fitting = false) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val score: Vector => Double = (v) => { val margin = BLAS.dot(v, weights) @@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] ( val predict: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } - dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol))) + .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 12473cb2b5719..1979ab9eb6516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** @@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params def setScoreCol(value: String): this.type = set(scoreCol, value) def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema @@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params require(labelType == DoubleType, s"Label column ${map(labelCol)} must be double type but found $labelType") - import dataset.sqlContext._ - val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) .map { case Row(score: Double, label: Double) => (score, label) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72825f6e02182..e7bdb070c8193 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.types.{StructField, StructType} @@ -43,14 +43,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol).attr) - .map { case Row(v: Vector) => - v - } + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) Params.inheritValues(map, this, model) @@ -83,14 +79,13 @@ class StandardScalerModel private[ml] ( def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val scale: (Vector) => Vector = (v) => { scaler.transform(v) } - dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol))) } private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala new file mode 100644 index 0000000000000..f6437c7fbc8ed --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -0,0 +1,970 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import java.{util => ju} + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.netlib.util.intW + +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.dsl._ +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} +import org.apache.spark.util.random.XORShiftRandom + +/** + * Common params for ALS. + */ +private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam + with HasPredictionCol { + + /** Param for rank of the matrix factorization. */ + val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + def getRank: Int = get(rank) + + /** Param for number of user blocks. */ + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + def getNumUserBlocks: Int = get(numUserBlocks) + + /** Param for number of item blocks. */ + val numItemBlocks = + new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + def getNumItemBlocks: Int = get(numItemBlocks) + + /** Param to decide whether to use implicit preference. */ + val implicitPrefs = + new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + def getImplicitPrefs: Boolean = get(implicitPrefs) + + /** Param for the alpha parameter in the implicit preference formulation. */ + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + def getAlpha: Double = get(alpha) + + /** Param for the column name for user ids. */ + val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + def getUserCol: String = get(userCol) + + /** Param for the column name for item ids. */ + val itemCol = + new Param[String](this, "itemCol", "column name for item ids", Some("item")) + def getItemCol: String = get(itemCol) + + /** Param for the column name for ratings. */ + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + def getRatingCol: String = get(ratingCol) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @param paramMap extra params + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + assert(schema(map(userCol)).dataType == IntegerType) + assert(schema(map(itemCol)).dataType== IntegerType) + val ratingType = schema(map(ratingCol)).dataType + assert(ratingType == FloatType || ratingType == DoubleType) + val predictionColName = map(predictionCol) + assert(!schema.fieldNames.contains(predictionColName), + s"Prediction column $predictionColName already exists.") + val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false) + StructType(newFields) + } +} + +/** + * Model fitted by ALS. + */ +class ALSModel private[ml] ( + override val parent: ALS, + override val fittingParamMap: ParamMap, + k: Int, + userFactors: RDD[(Int, Array[Float])], + itemFactors: RDD[(Int, Array[Float])]) + extends Model[ALSModel] with ALSParams { + + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + import dataset.sqlContext._ + import org.apache.spark.ml.recommendation.ALSModel.Factor + val map = this.paramMap ++ paramMap + // TODO: Add DSL to simplify the code here. + val instanceTable = s"instance_$uid" + val userTable = s"user_$uid" + val itemTable = s"item_$uid" + val instances = dataset.as(instanceTable) + val users = userFactors.map { case (id, features) => + Factor(id, features) + }.as(userTable) + val items = itemFactors.map { case (id, features) => + Factor(id, features) + }.as(itemTable) + val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { + if (userFeatures != null && itemFeatures != null) { + blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + } else { + Float.NaN + } + } + val inputColumns = dataset.schema.fieldNames + val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features") + .as(map(predictionCol)) + val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction + instances + .join(users, Column(map(userCol)) === $"$userTable.id", "left") + .join(items, Column(map(itemCol)) === $"$itemTable.id", "left") + .select(outputColumns: _*) + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private object ALSModel { + /** Case class to convert factors to SchemaRDDs */ + private case class Factor(id: Int, features: Seq[Float]) +} + +/** + * Alternating Least Squares (ALS) matrix factorization. + * + * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, + * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices. + * The general approach is iterative. During each iteration, one of the factor matrices is held + * constant, while the other is solved for using least squares. The newly-solved factor matrix is + * then held constant while solving for the other factor matrix. + * + * This is a blocked implementation of the ALS factorization algorithm that groups the two sets + * of factors (referred to as "users" and "products") into blocks and reduces communication by only + * sending one copy of each user vector to each product block on each iteration, and only for the + * product blocks that need that user's feature vector. This is achieved by pre-computing some + * information about the ratings matrix to determine the "out-links" of each user (which blocks of + * products it will contribute to) and "in-link" information for each product (which of the feature + * vectors it receives from each user block it will depend on). This allows us to send only an + * array of feature vectors between each user block and product block, and have the product block + * find the users' ratings and update the products based on these messages. + * + * For implicit preference data, the algorithm used is based on + * "Collaborative Filtering for Implicit Feedback Datasets", available at + * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here. + * + * Essentially instead of finding the low-rank approximations to the rating matrix `R`, + * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of + * indicated user + * preferences rather than explicit ratings given to items. + */ +class ALS extends Estimator[ALSModel] with ALSParams { + + import org.apache.spark.ml.recommendation.ALS.Rating + + def setRank(value: Int): this.type = set(rank, value) + def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + def setAlpha(value: Double): this.type = set(alpha, value) + def setUserCol(value: String): this.type = set(userCol, value) + def setItemCol(value: String): this.type = set(itemCol, value) + def setRatingCol(value: String): this.type = set(ratingCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setRegParam(value: Double): this.type = set(regParam, value) + + /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + def setNumBlocks(value: Int): this.type = { + setNumUserBlocks(value) + setNumItemBlocks(value) + this + } + + setMaxIter(20) + setRegParam(1.0) + + override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { + val map = this.paramMap ++ paramMap + val ratings = dataset + .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType)) + .map { row => + new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } + val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), + numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), + maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), + alpha = map(alpha)) + val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) + Params.inheritValues(map, this, model) + model + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private[recommendation] object ALS extends Logging { + + /** Rating class for better code readability. */ + private[recommendation] case class Rating(user: Int, item: Int, rating: Float) + + /** Cholesky solver for least square problems. */ + private[recommendation] class CholeskySolver { + + private val upper = "U" + private val info = new intW(0) + + /** + * Solves a least squares problem with L2 regularization: + * + * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * + * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) + * @param lambda regularization constant, which will be scaled by n + * @return the solution x + */ + def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + val k = ne.k + // Add scaled lambda to the diagonals of AtA. + val scaledlambda = lambda * ne.n + var i = 0 + var j = 2 + while (i < ne.triK) { + ne.ata(i) += scaledlambda + i += j + j += 1 + } + lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dppsv returned $code.") + val x = new Array[Float](k) + i = 0 + while (i < k) { + x(i) = ne.atb(i).toFloat + i += 1 + } + ne.reset() + x + } + } + + /** Representing a normal equation (ALS' subproblem). */ + private[recommendation] class NormalEquation(val k: Int) extends Serializable { + + /** Number of entries in the upper triangular part of a k-by-k matrix. */ + val triK = k * (k + 1) / 2 + /** A^T^ * A */ + val ata = new Array[Double](triK) + /** A^T^ * b */ + val atb = new Array[Double](k) + /** Number of observations. */ + var n = 0 + + private val da = new Array[Double](k) + private val upper = "U" + + private def copyToDouble(a: Array[Float]): Unit = { + var i = 0 + while (i < k) { + da(i) = a(i) + i += 1 + } + } + + /** Adds an observation. */ + def add(a: Array[Float], b: Float): this.type = { + require(a.size == k) + copyToDouble(a) + blas.dspr(upper, k, 1.0, da, 1, ata) + blas.daxpy(k, b.toDouble, da, 1, atb, 1) + n += 1 + this + } + + /** + * Adds an observation with implicit feedback. Note that this does not increment the counter. + */ + def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + require(a.size == k) + // Extension to the original paper to handle b < 0. confidence is a function of |b| instead + // so that it is never negative. + val confidence = 1.0 + alpha * math.abs(b) + copyToDouble(a) + blas.dspr(upper, k, confidence - 1.0, da, 1, ata) + // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. + if (b > 0) { + blas.daxpy(k, confidence, da, 1, atb, 1) + } + this + } + + /** Merges another normal equation object. */ + def merge(other: NormalEquation): this.type = { + require(other.k == k) + blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1) + blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1) + n += other.n + this + } + + /** Resets everything to zero, which should be called after each solve. */ + def reset(): Unit = { + ju.Arrays.fill(ata, 0.0) + ju.Arrays.fill(atb, 0.0) + n = 0 + } + } + + /** + * Implementation of the ALS algorithm. + */ + private def train( + ratings: RDD[Rating], + rank: Int = 10, + numUserBlocks: Int = 10, + numItemBlocks: Int = 10, + maxIter: Int = 10, + regParam: Double = 1.0, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = { + val userPart = new HashPartitioner(numUserBlocks) + val itemPart = new HashPartitioner(numItemBlocks) + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + val blockRatings = partitionRatings(ratings, userPart, itemPart).cache() + val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart) + // materialize blockRatings and user blocks + userOutBlocks.count() + val swappedBlockRatings = blockRatings.map { + case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => + ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) + } + val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart) + // materialize item blocks + itemOutBlocks.count() + var userFactors = initialize(userInBlocks, rank) + var itemFactors = initialize(itemInBlocks, rank) + if (implicitPrefs) { + for (iter <- 1 to maxIter) { + userFactors.setName(s"userFactors-$iter").persist() + val previousItemFactors = itemFactors + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder, implicitPrefs, alpha) + previousItemFactors.unpersist() + itemFactors.setName(s"itemFactors-$iter").persist() + val previousUserFactors = userFactors + userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder, implicitPrefs, alpha) + previousUserFactors.unpersist() + } + } else { + for (iter <- 0 until maxIter) { + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder) + userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder) + } + } + val userIdAndFactors = userInBlocks + .mapValues(_.srcIds) + .join(userFactors) + .values + .setName("userFactors") + .cache() + userIdAndFactors.count() + itemFactors.unpersist() + val itemIdAndFactors = itemInBlocks + .mapValues(_.srcIds) + .join(itemFactors) + .values + .setName("itemFactors") + .cache() + itemIdAndFactors.count() + userInBlocks.unpersist() + userOutBlocks.unpersist() + itemInBlocks.unpersist() + itemOutBlocks.unpersist() + blockRatings.unpersist() + val userOutput = userIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + (userOutput, itemOutput) + } + + /** + * Factor block that stores factors (Array[Float]) in an Array. + */ + private type FactorBlock = Array[Array[Float]] + + /** + * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to + * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the + * src factors in this block to send to dst block 0. + */ + private type OutBlock = Array[Array[Int]] + + /** + * In-link block for computing src (user/item) factors. This includes the original src IDs + * of the elements within this block as well as encoded dst (item/user) indices and corresponding + * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original + * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. + * For example, if we have an in-link record + * + * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * + * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which + * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * + * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can + * compute src factors one after another using only one normal equation instance. + * + * @param srcIds src ids (ordered) + * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and + * ratings are associated with srcIds(i). + * @param dstEncodedIndices encoded dst indices + * @param ratings ratings + * + * @see [[LocalIndexEncoder]] + */ + private[recommendation] case class InBlock( + srcIds: Array[Int], + dstPtrs: Array[Int], + dstEncodedIndices: Array[Int], + ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = ratings.size + + require(dstEncodedIndices.size == size) + require(dstPtrs.size == srcIds.size + 1) + } + + /** + * Initializes factors randomly given the in-link blocks. + * + * @param inBlocks in-link blocks + * @param rank rank + * @return initialized factor blocks + */ + private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = { + // Choose a unit vector uniformly at random from the unit sphere, but from the + // "first quadrant" where all elements are nonnegative. This can be done by choosing + // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. + // This appears to create factorizations that have a slightly better reconstruction + // (<1%) compared picking elements uniformly at random in [0,1]. + inBlocks.map { case (srcBlockId, inBlock) => + val random = new XORShiftRandom(srcBlockId) + val factors = Array.fill(inBlock.srcIds.size) { + val factor = Array.fill(rank)(random.nextGaussian().toFloat) + val nrm = blas.snrm2(rank, factor, 1) + blas.sscal(rank, 1.0f / nrm, factor, 1) + factor + } + (srcBlockId, factors) + } + } + + /** + * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays. + */ + private[recommendation] + case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = srcIds.size + require(dstIds.size == size) + require(ratings.size == size) + } + + /** + * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing. + */ + private[recommendation] class RatingBlockBuilder extends Serializable { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstIds = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + var size = 0 + + /** Adds a rating. */ + def add(r: Rating): this.type = { + size += 1 + srcIds += r.user + dstIds += r.item + ratings += r.rating + this + } + + /** Merges another [[RatingBlockBuilder]]. */ + def merge(other: RatingBlock): this.type = { + size += other.srcIds.size + srcIds ++= other.srcIds + dstIds ++= other.dstIds + ratings ++= other.ratings + this + } + + /** Builds a [[RatingBlock]]. */ + def build(): RatingBlock = { + RatingBlock(srcIds.result(), dstIds.result(), ratings.result()) + } + } + + /** + * Partitions raw ratings into blocks. + * + * @param ratings raw ratings + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * + * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock) + */ + private def partitionRatings( + ratings: RDD[Rating], + srcPart: Partitioner, + dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = { + + /* The implementation produces the same result as the following but generates less objects. + + ratings.map { r => + ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + }.aggregateByKey(new RatingBlockBuilder)( + seqOp = (b, r) => b.add(r), + combOp = (b0, b1) => b0.merge(b1.build())) + .mapValues(_.build()) + */ + + val numPartitions = srcPart.numPartitions * dstPart.numPartitions + ratings.mapPartitions { iter => + val builders = Array.fill(numPartitions)(new RatingBlockBuilder) + iter.flatMap { r => + val srcBlockId = srcPart.getPartition(r.user) + val dstBlockId = dstPart.getPartition(r.item) + val idx = srcBlockId + srcPart.numPartitions * dstBlockId + val builder = builders(idx) + builder.add(r) + if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k + builders(idx) = new RatingBlockBuilder + Iterator.single(((srcBlockId, dstBlockId), builder.build())) + } else { + Iterator.empty + } + } ++ { + builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) => + val srcBlockId = idx % srcPart.numPartitions + val dstBlockId = idx / srcPart.numPartitions + ((srcBlockId, dstBlockId), block.build()) + } + } + }.groupByKey().mapValues { blocks => + val builder = new RatingBlockBuilder + blocks.foreach(builder.merge) + builder.build() + }.setName("ratingBlocks") + } + + /** + * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples. + * @param encoder encoder for dst indices + */ + private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstEncodedIndices = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + + /** + * Adds a dst block of (srcId, dstLocalIndex, rating) tuples. + * + * @param dstBlockId dst block ID + * @param srcIds original src IDs + * @param dstLocalIndices dst local indices + * @param ratings ratings + */ + def add( + dstBlockId: Int, + srcIds: Array[Int], + dstLocalIndices: Array[Int], + ratings: Array[Float]): this.type = { + val sz = srcIds.size + require(dstLocalIndices.size == sz) + require(ratings.size == sz) + this.srcIds ++= srcIds + this.ratings ++= ratings + var j = 0 + while (j < sz) { + this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j)) + j += 1 + } + this + } + + /** Builds a [[UncompressedInBlock]]. */ + def build(): UncompressedInBlock = { + new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result()) + } + } + + /** + * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays. + */ + private[recommendation] class UncompressedInBlock( + val srcIds: Array[Int], + val dstEncodedIndices: Array[Int], + val ratings: Array[Float]) { + + /** Size the of block. */ + def size: Int = srcIds.size + + /** + * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a + * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. + */ + def compress(): InBlock = { + val sz = size + assert(sz > 0, "Empty in-link block should not exist.") + sort() + val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int] + val dstCountsBuilder = mutable.ArrayBuilder.make[Int] + var preSrcId = srcIds(0) + uniqueSrcIdsBuilder += preSrcId + var curCount = 1 + var i = 1 + var j = 0 + while (i < sz) { + val srcId = srcIds(i) + if (srcId != preSrcId) { + uniqueSrcIdsBuilder += srcId + dstCountsBuilder += curCount + preSrcId = srcId + j += 1 + curCount = 0 + } + curCount += 1 + i += 1 + } + dstCountsBuilder += curCount + val uniqueSrcIds = uniqueSrcIdsBuilder.result() + val numUniqueSrdIds = uniqueSrcIds.size + val dstCounts = dstCountsBuilder.result() + val dstPtrs = new Array[Int](numUniqueSrdIds + 1) + var sum = 0 + i = 0 + while (i < numUniqueSrdIds) { + sum += dstCounts(i) + i += 1 + dstPtrs(i) = sum + } + InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings) + } + + private def sort(): Unit = { + val sz = size + // Since there might be interleaved log messages, we insert a unique id for easy pairing. + val sortId = Utils.random.nextInt() + logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)") + val start = System.nanoTime() + val sorter = new Sorter(new UncompressedInBlockSort) + sorter.sort(this, 0, size, Ordering[IntWrapper]) + val duration = (System.nanoTime() - start) / 1e9 + logDebug(s"Sorting took $duration seconds. (sortId = $sortId)") + } + } + + /** + * A wrapper that holds a primitive integer key. + * + * @see [[UncompressedInBlockSort]] + */ + private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { + override def compare(that: IntWrapper): Int = { + key.compare(that.key) + } + } + + /** + * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]]. + */ + private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] { + + override def newKey(): IntWrapper = new IntWrapper() + + override def getKey( + data: UncompressedInBlock, + pos: Int, + reuse: IntWrapper): IntWrapper = { + if (reuse == null) { + new IntWrapper(data.srcIds(pos)) + } else { + reuse.key = data.srcIds(pos) + reuse + } + } + + override def getKey( + data: UncompressedInBlock, + pos: Int): IntWrapper = { + getKey(data, pos, null) + } + + private def swapElements[@specialized(Int, Float) T]( + data: Array[T], + pos0: Int, + pos1: Int): Unit = { + val tmp = data(pos0) + data(pos0) = data(pos1) + data(pos1) = tmp + } + + override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = { + swapElements(data.srcIds, pos0, pos1) + swapElements(data.dstEncodedIndices, pos0, pos1) + swapElements(data.ratings, pos0, pos1) + } + + override def copyRange( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int, + length: Int): Unit = { + System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length) + System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length) + System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length) + } + + override def allocate(length: Int): UncompressedInBlock = { + new UncompressedInBlock( + new Array[Int](length), new Array[Int](length), new Array[Float](length)) + } + + override def copyElement( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int): Unit = { + dst.srcIds(dstPos) = src.srcIds(srcPos) + dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos) + dst.ratings(dstPos) = src.ratings(srcPos) + } + } + + /** + * Creates in-blocks and out-blocks from rating blocks. + * @param prefix prefix for in/out-block names + * @param ratingBlocks rating blocks + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * @return (in-blocks, out-blocks) + */ + private def makeBlocks( + prefix: String, + ratingBlocks: RDD[((Int, Int), RatingBlock)], + srcPart: Partitioner, + dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = { + val inBlocks = ratingBlocks.map { + case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => + // The implementation is a faster version of + // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap + val start = System.nanoTime() + val dstIdSet = new OpenHashSet[Int](1 << 20) + dstIds.foreach(dstIdSet.add) + val sortedDstIds = new Array[Int](dstIdSet.size) + var i = 0 + var pos = dstIdSet.nextPos(0) + while (pos != -1) { + sortedDstIds(i) = dstIdSet.getValue(pos) + pos = dstIdSet.nextPos(pos + 1) + i += 1 + } + assert(i == dstIdSet.size) + ju.Arrays.sort(sortedDstIds) + val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size) + i = 0 + while (i < sortedDstIds.size) { + dstIdToLocalIndex.update(sortedDstIds(i), i) + i += 1 + } + logDebug( + "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") + val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) + (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) + }.groupByKey(new HashPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks").cache() + val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => + val encoder = new LocalIndexEncoder(dstPart.numPartitions) + val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) + var i = 0 + val seen = new Array[Boolean](dstPart.numPartitions) + while (i < srcIds.size) { + var j = dstPtrs(i) + ju.Arrays.fill(seen, false) + while (j < dstPtrs(i + 1)) { + val dstBlockId = encoder.blockId(dstEncodedIndices(j)) + if (!seen(dstBlockId)) { + activeIds(dstBlockId) += i // add the local index in this out-block + seen(dstBlockId) = true + } + j += 1 + } + i += 1 + } + activeIds.map { x => + x.result() + } + }.setName(prefix + "OutBlocks").cache() + (inBlocks, outBlocks) + } + + /** + * Compute dst factors by constructing and solving least square problems. + * + * @param srcFactorBlocks src factors + * @param srcOutBlocks src out-blocks + * @param dstInBlocks dst in-blocks + * @param rank rank + * @param regParam regularization constant + * @param srcEncoder encoder for src local indices + * @param implicitPrefs whether to use implicit preference + * @param alpha the alpha constant in the implicit preference formulation + * + * @return dst factors + */ + private def computeFactors( + srcFactorBlocks: RDD[(Int, FactorBlock)], + srcOutBlocks: RDD[(Int, OutBlock)], + dstInBlocks: RDD[(Int, InBlock)], + rank: Int, + regParam: Double, + srcEncoder: LocalIndexEncoder, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): RDD[(Int, FactorBlock)] = { + val numSrcBlocks = srcFactorBlocks.partitions.size + val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None + val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { + case (srcBlockId, (srcOutBlock, srcFactors)) => + srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) => + (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) + } + } + val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size)) + dstInBlocks.join(merged).mapValues { + case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => + val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) + srcFactors.foreach { case (srcBlockId, factors) => + sortedSrcFactors(srcBlockId) = factors + } + val dstFactors = new Array[Array[Float]](dstIds.size) + var j = 0 + val ls = new NormalEquation(rank) + val solver = new CholeskySolver // TODO: add NNLS solver + while (j < dstIds.size) { + ls.reset() + if (implicitPrefs) { + ls.merge(YtY.get) + } + var i = srcPtrs(j) + while (i < srcPtrs(j + 1)) { + val encoded = srcEncodedIndices(i) + val blockId = srcEncoder.blockId(encoded) + val localIndex = srcEncoder.localIndex(encoded) + val srcFactor = sortedSrcFactors(blockId)(localIndex) + val rating = ratings(i) + if (implicitPrefs) { + ls.addImplicit(srcFactor, rating, alpha) + } else { + ls.add(srcFactor, rating) + } + i += 1 + } + dstFactors(j) = solver.solve(ls, regParam) + j += 1 + } + dstFactors + } + } + + /** + * Computes the Gramian matrix of user or item factors, which is only used in implicit preference. + * Caching of the input factors is handled in [[ALS#train]]. + */ + private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { + factorBlocks.values.aggregate(new NormalEquation(rank))( + seqOp = (ne, factors) => { + factors.foreach(ne.add(_, 0.0f)) + ne + }, + combOp = (ne1, ne2) => ne1.merge(ne2)) + } + + /** + * Encoder for storing (blockId, localIndex) into a single integer. + * + * We use the leading bits (including the sign bit) to store the block id and the rest to store + * the local index. This is based on the assumption that users/items are approximately evenly + * partitioned. With this assumption, we should be able to encode two billion distinct values. + * + * @param numBlocks number of blocks + */ + private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable { + + require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.") + + private[this] final val numLocalIndexBits = + math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31) + private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1 + + /** Encodes a (blockId, localIndex) into a single integer. */ + def encode(blockId: Int, localIndex: Int): Int = { + require(blockId < numBlocks) + require((localIndex & ~localIndexMask) == 0) + (blockId << numLocalIndexBits) | localIndex + } + + /** Gets the block id from an encoded index. */ + @inline + def blockId(encoded: Int): Int = { + encoded >>> numLocalIndexBits + } + + /** Gets the local index from an encoded index. */ + @inline + def localIndex(encoded: Int): Int = { + encoded & localIndexMask + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 08fe99176424a..5d51c51346665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setEvaluator(value: Evaluator): this.type = set(evaluator, value) def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { val map = this.paramMap ++ paramMap val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) @@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val epm = map(estimatorParamMaps) val numModels = epm.size val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset, map(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.applySchema(training, schema).cache() val validationDataset = sqlCtx.applySchema(validation, schema).cache() @@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 6b5c934f015ba..11633e8242313 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -279,45 +279,81 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { - // Initialize each run's center to a random point + // Initialize empty centers and point costs. + val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) + var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + + // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + + /** Merges new centers to centers. */ + def mergeNewCenters(): Unit = { + var r = 0 + while (r < runs) { + centers(r) ++= newCenters(r) + newCenters(r).clear() + r += 1 + } + } // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers + // to their squared distance from that run's centers. Note that only distances between points + // and new centers are computed in each iteration. var step = 0 while (step < initializationSteps) { - val bcCenters = data.context.broadcast(centers) - val sumCosts = data.flatMap { point => - (0 until runs).map { r => - (r, KMeans.pointCost(bcCenters.value(r), point)) - } - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => + val bcNewCenters = data.context.broadcast(newCenters) + val preCosts = costs + costs = data.zip(preCosts).map { case (point, cost) => + Vectors.dense( + Array.tabulate(runs) { r => + math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) + }) + }.cache() + val sumCosts = costs + .aggregate(Vectors.zeros(runs))( + seqOp = (s, v) => { + // s += v + axpy(1.0, v, s) + s + }, + combOp = (s0, s1) => { + // s0 += s1 + axpy(1.0, s1, s0) + s0 + } + ) + preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - points.flatMap { p => - (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) - }.map((_, p)) + pointsWithCosts.flatMap { case (p, c) => + val rs = (0 until runs).filter { r => + rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) + } + if (rs.length > 0) Some(p, rs) else None } }.collect() - chosen.foreach { case (r, p) => - centers(r) += p.toDense + mergeNewCenters() + chosen.foreach { case (p, rs) => + rs.foreach(newCenters(_) += p.toDense) } step += 1 } + mergeNewCenters() + costs.unpersist(blocking = false) + // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => - (0 until runs).map { r => + Iterator.tabulate(runs) { r => ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 3414daccd7ca4..34e0392f1b21a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -257,80 +257,58 @@ private[spark] object BLAS extends Serializable with Logging { /** * C := alpha * A * B + beta * C - * @param transA whether to use the transpose of matrix A (true), or A itself (false). - * @param transB whether to use the transpose of matrix B (true), or B itself (false). * @param alpha a scalar to scale the multiplication A * B. * @param A the matrix A that will be left multiplied to B. Size of m x k. * @param B the matrix B that will be left multiplied by A. Size of k x n. * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. + * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false. */ def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: Matrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { + require(!C.isTransposed, + "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0) { logDebug("gemm: alpha is equal to 0. Returning C.") } else { A match { case sparse: SparseMatrix => - gemm(transA, transB, alpha, sparse, B, beta, C) + gemm(alpha, sparse, B, beta, C) case dense: DenseMatrix => - gemm(transA, transB, alpha, dense, B, beta, C) + gemm(alpha, dense, B, beta, C) case _ => throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") } } } - /** - * C := alpha * A * B + beta * C - * - * @param alpha a scalar to scale the multiplication A * B. - * @param A the matrix A that will be left multiplied to B. Size of m x k. - * @param B the matrix B that will be left multiplied by A. Size of k x n. - * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. - */ - def gemm( - alpha: Double, - A: Matrix, - B: DenseMatrix, - beta: Double, - C: DenseMatrix): Unit = { - gemm(false, false, alpha, A, B, beta, C) - } - /** * C := alpha * A * B + beta * C * For `DenseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: DenseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols - val tAstr = if (!transA) "N" else "T" - val tBstr = if (!transB) "N" else "T" - - require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") - require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") - require(nB == C.numCols, - s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") - - nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, - beta, C.values, C.numRows) + val tAstr = if (A.isTransposed) "T" else "N" + val tBstr = if (B.isTransposed) "T" else "N" + val lda = if (!A.isTransposed) A.numRows else A.numCols + val ldb = if (!B.isTransposed) B.numRows else B.numCols + + require(A.numCols == B.numRows, + s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}") + require(A.numRows == C.numRows, + s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}") + require(B.numCols == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}") + + nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda, + B.values, ldb, beta, C.values, C.numRows) } /** @@ -338,17 +316,15 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: SparseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols + val mA: Int = A.numRows + val nB: Int = B.numCols + val kA: Int = A.numCols + val kB: Int = B.numRows require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") @@ -358,23 +334,23 @@ private[spark] object BLAS extends Serializable with Logging { val Avals = A.values val Bvals = B.values val Cvals = C.values - val Arows = if (!transA) A.rowIndices else A.colPtrs - val Acols = if (!transA) A.colPtrs else A.rowIndices + val ArowIndices = A.rowIndices + val AcolPtrs = A.colPtrs // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (transA){ + if (A.isTransposed){ var colCounterForB = 0 - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var rowCounterForA = 0 val Cstart = colCounterForB * mA val Bstart = colCounterForB * kA while (rowCounterForA < mA) { - var i = Arows(rowCounterForA) - val indEnd = Arows(rowCounterForA + 1) + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * Bvals(Bstart + Acols(i)) + sum += Avals(i) * Bvals(Bstart + ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -385,19 +361,19 @@ private[spark] object BLAS extends Serializable with Logging { } } else { while (colCounterForB < nB) { - var rowCounter = 0 + var rowCounterForA = 0 val Cstart = colCounterForB * mA - while (rowCounter < mA) { - var i = Arows(rowCounter) - val indEnd = Arows(rowCounter + 1) + while (rowCounterForA < mA) { + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(colCounterForB, Acols(i)) + sum += Avals(i) * B(ArowIndices(i), colCounterForB) i += 1 } - val Cindex = Cstart + rowCounter + val Cindex = Cstart + rowCounterForA Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha - rowCounter += 1 + rowCounterForA += 1 } colCounterForB += 1 } @@ -410,17 +386,17 @@ private[spark] object BLAS extends Serializable with Logging { // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of // B, and added to C. var colCounterForB = 0 // the column to be updated in C - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var colCounterForA = 0 // The column of A to multiply with the row of B val Bstart = colCounterForB * kB val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) val Bval = Bvals(Bstart + colCounterForA) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -432,11 +408,11 @@ private[spark] object BLAS extends Serializable with Logging { var colCounterForA = 0 // The column of A to multiply with the row of B val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) - val Bval = B(colCounterForB, colCounterForA) * alpha + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) + val Bval = B(colCounterForA, colCounterForB) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -449,7 +425,6 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * @param trans whether to use the transpose of matrix A (true), or A itself (false). * @param alpha a scalar to scale the multiplication A * x. * @param A the matrix A that will be left multiplied to x. Size of m x n. * @param x the vector x that will be left multiplied by A. Size of n x 1. @@ -457,65 +432,43 @@ private[spark] object BLAS extends Serializable with Logging { * @param y the resulting vector y. Size of m x 1. */ def gemv( - trans: Boolean, alpha: Double, A: Matrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - - val mA: Int = if (!trans) A.numRows else A.numCols - val nx: Int = x.size - val nA: Int = if (!trans) A.numCols else A.numRows - - require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") - require(mA == y.size, - s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + require(A.numCols == x.size, + s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") + require(A.numRows == y.size, + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { A match { case sparse: SparseMatrix => - gemv(trans, alpha, sparse, x, beta, y) + gemv(alpha, sparse, x, beta, y) case dense: DenseMatrix => - gemv(trans, alpha, dense, x, beta, y) + gemv(alpha, dense, x, beta, y) case _ => throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") } } } - /** - * y := alpha * A * x + beta * y - * - * @param alpha a scalar to scale the multiplication A * x. - * @param A the matrix A that will be left multiplied to x. Size of m x n. - * @param x the vector x that will be left multiplied by A. Size of n x 1. - * @param beta a scalar that can be used to scale vector y. - * @param y the resulting vector y. Size of m x 1. - */ - def gemv( - alpha: Double, - A: Matrix, - x: DenseVector, - beta: Double, - y: DenseVector): Unit = { - gemv(false, alpha, A, x, beta, y) - } - /** * y := alpha * A * x + beta * y * For `DenseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val tStrA = if (!trans) "N" else "T" - nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + val tStrA = if (A.isTransposed) "T" else "N" + val mA = if (!A.isTransposed) A.numRows else A.numCols + val nA = if (!A.isTransposed) A.numCols else A.numRows + nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } @@ -524,24 +477,21 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val xValues = x.values val yValues = y.values - - val mA: Int = if (!trans) A.numRows else A.numCols - val nA: Int = if (!trans) A.numCols else A.numRows + val mA: Int = A.numRows + val nA: Int = A.numCols val Avals = A.values - val Arows = if (!trans) A.rowIndices else A.colPtrs - val Acols = if (!trans) A.colPtrs else A.rowIndices + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (trans) { + if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { var i = Arows(rowCounter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 5a7281ec6dc3c..ad7e86827b368 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -34,8 +34,17 @@ sealed trait Matrix extends Serializable { /** Number of columns. */ def numCols: Int + /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + val isTransposed: Boolean = false + /** Converts to a dense array in column major. */ - def toArray: Array[Double] + def toArray: Array[Double] = { + val newArray = new Array[Double](numRows * numCols) + foreachActive { (i, j, v) => + newArray(j * numRows + i) = v + } + newArray + } /** Converts to a breeze matrix. */ private[mllib] def toBreeze: BM[Double] @@ -52,10 +61,13 @@ sealed trait Matrix extends Serializable { /** Get a deep copy of the matrix. */ def copy: Matrix + /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + def transpose: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ def multiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) + BLAS.gemm(1.0, this, y, 0.0, C) C } @@ -66,20 +78,6 @@ sealed trait Matrix extends Serializable { output } - /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(true, false, 1.0, this, y, 0.0, C) - C - } - - /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { - val output = new DenseVector(new Array[Double](numCols)) - BLAS.gemv(true, 1.0, this, y, 0.0, output) - output - } - /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() @@ -92,6 +90,16 @@ sealed trait Matrix extends Serializable { * backing array. For example, an operation such as addition or subtraction will only be * performed on the non-zero values in a `SparseMatrix`. */ private[mllib] def update(f: Double => Double): Matrix + + /** + * Applies a function `f` to all the active elements of dense and sparse matrix. The ordering + * of the elements are not defined. + * + * @param f the function takes three parameters where the first two parameters are the row + * and column indices respectively with the type `Int`, and the final parameter is the + * corresponding value in the matrix with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Int, Double) => Unit) } /** @@ -108,13 +116,35 @@ sealed trait Matrix extends Serializable { * @param numRows number of rows * @param numCols number of columns * @param values matrix entries in column major + * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in + * row major. */ -class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { +class DenseMatrix( + val numRows: Int, + val numCols: Int, + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") - override def toArray: Array[Double] = values + /** + * Column-major dense matrix. + * The entry values are stored in a single array of doubles with columns listed in sequence. + * For example, the following matrix + * {{{ + * 1.0 2.0 + * 3.0 4.0 + * 5.0 6.0 + * }}} + * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ + def this(numRows: Int, numCols: Int, values: Array[Double]) = + this(numRows, numCols, values, false) override def equals(o: Any) = o match { case m: DenseMatrix => @@ -122,13 +152,22 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) case _ => false } - private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BDM[Double](numRows, numCols, values) + } else { + val breezeMatrix = new BDM[Double](numCols, numRows, values) + breezeMatrix.t + } + } private[mllib] def apply(i: Int): Double = values(i) private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) - private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + private[mllib] def index(i: Int, j: Int): Int = { + if (!isTransposed) i + numRows * j else j + numCols * i + } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { values(index(i, j)) = v @@ -148,7 +187,38 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) this } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. */ + override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + // outer loop over columns + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + f(i, j, values(indStart + i)) + i += 1 + } + j += 1 + } + } else { + // outer loop over rows + var i = 0 + while (i < numRows) { + var j = 0 + val indStart = i * numCols + while (j < numCols) { + f(i, j, values(indStart + j)) + j += 1 + } + i += 1 + } + } + } + + /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. */ def toSparse(): SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -157,9 +227,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) var j = 0 while (j < numCols) { var i = 0 - val indStart = j * numRows while (i < numRows) { - val v = values(indStart + i) + val v = values(index(i, j)) if (v != 0.0) { rowIndices += i spVals += v @@ -271,49 +340,73 @@ object DenseMatrix { * @param rowIndices the row index of the entry. They must be in strictly increasing order for each * column * @param values non-zero matrix entries in column major + * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered + * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, + * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ class SparseMatrix( val numRows: Int, val numCols: Int, val colPtrs: Array[Int], val rowIndices: Array[Int], - val values: Array[Double]) extends Matrix { + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + - s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + - s"numCols: $numCols") + // The Or statement is for the case when the matrix is transposed + require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + + "column indices should be the number of columns + 1. Currently, colPointers.length: " + + s"${colPtrs.length}, numCols: $numCols") require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") - override def toArray: Array[Double] = { - val arr = new Array[Double](numRows * numCols) - var j = 0 - while (j < numCols) { - var i = colPtrs(j) - val indEnd = colPtrs(j + 1) - val offset = j * numRows - while (i < indEnd) { - val rowIndex = rowIndices(i) - arr(offset + rowIndex) = values(i) - i += 1 - } - j += 1 - } - arr + /** + * Column-major sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. They must be in strictly increasing + * order for each column + * @param values non-zero matrix entries in column major + */ + def this( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + } else { + val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices) + breezeMatrix.t + } } - private[mllib] def toBreeze: BM[Double] = - new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) - private[mllib] def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } private[mllib] def index(i: Int, j: Int): Int = { - Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + if (!isTransposed) { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } else { + Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j) + } } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { @@ -322,7 +415,7 @@ class SparseMatrix( throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. Only non-zero elements in Sparse Matrices can be updated.") } else { - values(index(i, j)) = v + values(ind) = v } } @@ -341,7 +434,38 @@ class SparseMatrix( this } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. */ + override def transpose: Matrix = + new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + var j = 0 + while (j < numCols) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + while (idx < idxEnd) { + f(rowIndices(idx), j, values(idx)) + idx += 1 + } + j += 1 + } + } else { + var i = 0 + while (i < numRows) { + var idx = colPtrs(i) + val idxEnd = colPtrs(i + 1) + while (idx < idxEnd) { + val j = rowIndices(idx) + f(i, j, values(idx)) + idx += 1 + } + i += 1 + } + } + } + + /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. */ def toDense(): DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } @@ -557,10 +681,9 @@ object Matrices { private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = { breeze match { case dm: BDM[Double] => - require(dm.majorStride == dm.rows, - "Do not support stride size different from the number of rows.") - new DenseMatrix(dm.rows, dm.cols, dm.data) + new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // There is no isTranspose flag for sparse matrices in Breeze new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( @@ -679,46 +802,28 @@ object Matrices { new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - val nCols = spMat.numCols - while (j < nCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i, j + startCol, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nCols = mat.numCols + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i, j + startCol, v) + cnt += 1 } - j += 1 - } - startCol += nCols - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i, j + startCol, v)) } - i += 1 } - j += 1 - } - startCol += nCols - data + startCol += nCols + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } @@ -744,14 +849,12 @@ object Matrices { require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + "don't match!") mat match { - case sparse: SparseMatrix => - hasSparse = true - case dense: DenseMatrix => + case sparse: SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") } numRows += mat.numRows - } if (!hasSparse) { val allValues = new Array[Double](numRows * numCols) @@ -759,61 +862,37 @@ object Matrices { matrices.foreach { mat => var j = 0 val nRows = mat.numRows - val values = mat.toArray - while (j < numCols) { - var i = 0 + mat.foreachActive { (i, j, v) => val indStart = j * numRows + startRow - val subMatStart = j * nRows - while (i < nRows) { - allValues(indStart + i) = values(subMatStart + i) - i += 1 - } - j += 1 + allValues(indStart + i) = v } startRow += nRows } new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - while (j < numCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i + startRow, j, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nRows = mat.numRows + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i + startRow, j, v) + cnt += 1 } - j += 1 - } - startRow += spMat.numRows - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startRow += nRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i + startRow, j, v)) } - i += 1 } - j += 1 - } - startRow += nRows - data + startRow += nRows + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7ee0224ad4662..2834ea75ceb8f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -333,7 +333,7 @@ object Vectors { math.pow(sum, 1.0 / p) } } - + /** * Returns the squared distance between two Vectors. * @param v1 first Vector. @@ -341,8 +341,9 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { + require(v1.size == v2.size, "vector dimension mismatch") var squaredDistance = 0.0 - (v1, v2) match { + (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => val v1Values = v1.values val v1Indices = v1.indices @@ -350,12 +351,12 @@ object Vectors { val v2Indices = v2.indices val nnzv1 = v1Indices.size val nnzv2 = v2Indices.size - + var kv1 = 0 var kv2 = 0 while (kv1 < nnzv1 || kv2 < nnzv2) { var score = 0.0 - + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { score = v1Values(kv1) kv1 += 1 @@ -370,18 +371,23 @@ object Vectors { squaredDistance += score * score } - case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => + case (v1: SparseVector, v2: DenseVector) => squaredDistance = sqdist(v1, v2) - case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 => + case (v1: DenseVector, v2: SparseVector) => squaredDistance = sqdist(v2, v1) - // When a SparseVector is approximately dense, we treat it as a DenseVector - case (v1, v2) => - squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) => - val score = elems._1 - elems._2 - distance + score * score + case (DenseVector(vv1), DenseVector(vv2)) => + var kv = 0 + val sz = vv1.size + while (kv < sz) { + val score = vv1(kv) - vv2(kv) + squaredDistance += score * score + kv += 1 } + case _ => + throw new IllegalArgumentException("Do not support vector type " + v1.getClass + + " and " + v2.getClass) } squaredDistance } @@ -397,7 +403,7 @@ object Vectors { val nnzv1 = indices.size val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 - + while (kv2 < nnzv2) { var score = 0.0 if (kv2 != iv1) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index bee951a2e5e26..5f84677be238d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double) * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if - * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of * indicated user * preferences rather than explicit ratings given to items. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index e9304b5e5c650..482dd4b272d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -140,6 +140,7 @@ private class RandomForest ( logDebug("maxBins = " + metadata.maxBins) logDebug("featureSubsetStrategy = " + featureSubsetStrategy) logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. @@ -155,19 +156,12 @@ private class RandomForest ( // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val (subsample, withReplacement) = { - // TODO: Have a stricter check for RF in the strategy - val isRandomForest = numTrees > 1 - if (isRandomForest) { - (1.0, true) - } else { - (strategy.subsamplingRate, false) - } - } + val withReplacement = if (numTrees > 1) true else false val baggedInput - = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) - .persist(StorageLevel.MEMORY_AND_DISK) + = BaggedPoint.convertToBaggedRDD(treeInput, + strategy.subsamplingRate, numTrees, + withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..ed8e6a796f8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -68,6 +68,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +84,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..3308adb6752ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -156,6 +156,9 @@ class Strategy ( s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") + require(subsamplingRate > 0 && subsamplingRate <= 1, + s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + + s"$subsamplingRate") } /** Returns a shallow copy of this instance. */ @@ -173,11 +176,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0e02345aa3774..b7950e00786ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int) throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc val lbl = label.toInt require(lbl < stats.length, s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "Entropy does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7c83cd48e16a0..c946db9c0d1c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int) throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula val lbl = label.toInt require(lbl < stats.length, s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "GiniImpurity does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 47f1f46c6c260..56a9dbdd58b64 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +37,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 2eba83335bb58..f4ba23c44563e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -34,7 +34,7 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -55,7 +55,7 @@ public void logisticRegression() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } @@ -67,7 +67,7 @@ public void logisticRegressionWithSetters() { LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold .registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a9f1c4a2c3ca7..074b58c07df7a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +38,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4515084bc7ae9..2f175fb117941 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame class PipelineSuite extends FunSuite { @@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite { val estimator2 = mock[Estimator[MyModel]] val model2 = mock[MyModel] val transformer3 = mock[Transformer] - val dataset0 = mock[SchemaRDD] - val dataset1 = mock[SchemaRDD] - val dataset2 = mock[SchemaRDD] - val dataset3 = mock[SchemaRDD] - val dataset4 = mock[SchemaRDD] + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) @@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite { val estimator = mock[Estimator[MyModel]] val pipeline = new Pipeline() .setStages(Array(estimator, estimator)) - val dataset = mock[SchemaRDD] + val dataset = mock[DataFrame] intercept[IllegalArgumentException] { pipeline.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e8030fef55b1d..1912afce93b18 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -21,12 +21,12 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -36,34 +36,28 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { } test("logistic regression") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset) model.transform(dataset) - .select('label, 'prediction) + .select("label", "prediction") .collect() } test("logistic regression with setters") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) val model = lr.fit(dataset) model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select('label, 'score, 'prediction) + .select("label", "score", "prediction") .collect() } test("logistic regression fit and transform with varargs") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select('label, 'probability, 'prediction) + .select("label", "probability", "prediction") .collect() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala new file mode 100644 index 0000000000000..58289acdbc095 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import java.util.Random + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.scalatest.FunSuite + +import org.apache.spark.Logging +import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} + +class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { + + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("LocalIndexEncoder") { + val random = new Random + for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { + val encoder = new LocalIndexEncoder(numBlocks) + val maxLocalIndex = Int.MaxValue / numBlocks + val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++ + Seq((0, 0), (numBlocks - 1, maxLocalIndex)) + tests.foreach { case (blockId, localIndex) => + val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex." + val encoded = encoder.encode(blockId, localIndex) + assert(encoder.blockId(encoded) === blockId, err) + assert(encoder.localIndex(encoded) === localIndex, err) + } + } + } + + test("normal equation construction with explict feedback") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 3.0f) + .add(Array(4.0f, 5.0f), 6.0f) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 2) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5") + // b = np.matrix("3; 6") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + + val ne1 = new NormalEquation(2) + .add(Array(7.0f, 8.0f), 9.0f) + ne0.merge(ne1) + assert(ne0.n === 3) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5; 7 8") + // b = np.matrix("3; 6; 9") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f), 2.0f) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + } + intercept[IllegalArgumentException] { + val ne2 = new NormalEquation(3) + ne0.merge(ne2) + } + + ne0.reset() + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + } + + test("normal equation construction with implicit feedback") { + val k = 2 + val alpha = 0.5 + val ne0 = new NormalEquation(k) + .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) + .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) + .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 0) // addImplicit doesn't increase the count. + // NumPy code that computes the expected values: + // alpha = 0.5 + // A = np.matrix("-5 -4; -2 -1; 1 2") + // b = np.matrix("-3; 0; 3") + // b1 = b > 0 + // c = 1.0 + alpha * np.abs(b) + // C = np.diag(c.A1) + // I = np.eye(3) + // ata = A.transpose() * (C - I) * A + // atb = A.transpose() * C * b1 + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) + } + + test("CholeskySolver") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 4.0f) + .add(Array(1.0f, 3.0f), 9.0f) + .add(Array(1.0f, 4.0f), 16.0f) + val ne1 = new NormalEquation(k) + .merge(ne0) + + val chol = new CholeskySolver + val x0 = chol.solve(ne0, 0.0).map(_.toDouble) + // NumPy code that computes the expected solution: + // A = np.matrix("1 2; 1 3; 1 4") + // b = b = np.matrix("3; 6") + // x0 = np.linalg.lstsq(A, b)[0] + assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) + + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + + val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + // NumPy code that computes the expected solution, where lambda is scaled by n: + // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) + } + + test("RatingBlockBuilder") { + val emptyBuilder = new RatingBlockBuilder() + assert(emptyBuilder.size === 0) + val emptyBlock = emptyBuilder.build() + assert(emptyBlock.srcIds.isEmpty) + assert(emptyBlock.dstIds.isEmpty) + assert(emptyBlock.ratings.isEmpty) + + val builder0 = new RatingBlockBuilder() + .add(Rating(0, 1, 2.0f)) + .add(Rating(3, 4, 5.0f)) + assert(builder0.size === 2) + val builder1 = new RatingBlockBuilder() + .add(Rating(6, 7, 8.0f)) + .merge(builder0.build()) + assert(builder1.size === 3) + val block = builder1.build() + val ratings = Seq.tabulate(block.size) { i => + (block.srcIds(i), block.dstIds(i), block.ratings(i)) + }.toSet + assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f))) + } + + test("UncompressedInBlock") { + val encoder = new LocalIndexEncoder(10) + val uncompressed = new UncompressedInBlockBuilder(encoder) + .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f)) + .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f)) + .build() + assert(uncompressed.size === 5) + val records = Seq.tabulate(uncompressed.size) { i => + val dstEncodedIndex = uncompressed.dstEncodedIndices(i) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i)) + }.toSet + val expected = + Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f)) + assert(records === expected) + + val compressed = uncompressed.compress() + assert(compressed.size === 5) + assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3)) + assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) + var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] + var i = 0 + while (i < compressed.srcIds.size) { + var j = compressed.dstPtrs(i) + while (j < compressed.dstPtrs(i + 1)) { + val dstEncodedIndex = compressed.dstEncodedIndices(j) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j))) + j += 1 + } + i += 1 + } + assert(decompressed.toSet === expected) + } + + /** + * Generates an explicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genExplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val x = random.nextDouble() + if (x < totalFraction) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates an implicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genImplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + // The assumption of the implicit feedback model is that unobserved ratings are more likely to + // be negatives. + val positiveFraction = 0.8 + val negativeFraction = 1.0 - positiveFraction + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + val threshold = if (rating > 0) positiveFraction else negativeFraction + val observed = random.nextDouble() < threshold + if (observed) { + val x = random.nextDouble() + if (x < totalFraction) { + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + } + logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates random user/item factors, with i.i.d. values drawn from U(a, b). + * @param size number of users/items + * @param rank number of features + * @param random random number generator + * @param a min value of the support (default: -1) + * @param b max value of the support (default: 1) + * @return a sequence of (ID, factors) pairs + */ + private def genFactors( + size: Int, + rank: Int, + random: Random, + a: Float = -1.0f, + b: Float = 1.0f): Seq[(Int, Array[Float])] = { + require(size > 0 && size < Int.MaxValue / 3) + require(b > a) + val ids = mutable.Set.empty[Int] + while (ids.size < size) { + ids += random.nextInt() + } + val width = b - a + ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width))) + } + + /** + * Test ALS using the given training/test splits and parameters. + * @param training training dataset + * @param test test dataset + * @param rank rank of the matrix factorization + * @param maxIter max number of iterations + * @param regParam regularization constant + * @param implicitPrefs whether to use implicit preference + * @param numUserBlocks number of user blocks + * @param numItemBlocks number of item blocks + * @param targetRMSE target test RMSE + */ + def testALS( + training: RDD[Rating], + test: RDD[Rating], + rank: Int, + maxIter: Int, + regParam: Double, + implicitPrefs: Boolean = false, + numUserBlocks: Int = 2, + numItemBlocks: Int = 3, + targetRMSE: Double = 0.05): Unit = { + val sqlContext = this.sqlContext + import sqlContext.createSchemaRDD + val als = new ALS() + .setRank(rank) + .setRegParam(regParam) + .setImplicitPrefs(implicitPrefs) + .setNumUserBlocks(numUserBlocks) + .setNumItemBlocks(numItemBlocks) + val alpha = als.getAlpha + val model = als.fit(training) + val predictions = model.transform(test) + .select("rating", "prediction") + .map { case Row(rating: Float, prediction: Float) => + (rating.toDouble, prediction.toDouble) + } + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE + // with the confidence scores as weights. + val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val mse = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + }.mean() + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) + } + + test("exact rank-1 matrix") { + val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1) + testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001) + testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001) + } + + test("approximate rank-1 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01) + testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02) + testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02) + } + + test("approximate rank-2 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03) + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03) + } + + test("different block settings") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) { + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03, + numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks) + } + } + + test("more blocks than ratings") { + val (training, test) = + genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002, + numItemBlocks = 5, numUserBlocks = 5) + } + + test("implicit feedback") { + val (training, test) = + genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true, + targetRMSE = 0.3) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 41cc13da4d5b1..74104fa7a681a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -23,11 +23,11 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 771878e925ea7..b0b78acd6df16 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -169,16 +169,17 @@ class BLASSuite extends FunSuite { } test("gemm") { - val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0)) + val BT = B.transpose - assert(dA multiply B ~== expected absTol 1e-15) - assert(sA multiply B ~== expected absTol 1e-15) + assert(dA.multiply(B) ~== expected absTol 1e-15) + assert(sA.multiply(B) ~== expected absTol 1e-15) val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) val C2 = C1.copy @@ -188,6 +189,10 @@ class BLASSuite extends FunSuite { val C6 = C1.copy val C7 = C1.copy val C8 = C1.copy + val C9 = C1.copy + val C10 = C1.copy + val C11 = C1.copy + val C12 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) @@ -202,26 +207,40 @@ class BLASSuite extends FunSuite { withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemm(true, false, 1.0, dA, B, 2.0, C1) + gemm(1.0, dA.transpose, B, 2.0, C1) } } - val dAT = + val dATman = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) - val sAT = + val sATman = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply B ~== expected absTol 1e-15) - assert(sAT transposeMultiply B ~== expected absTol 1e-15) - - gemm(true, false, 1.0, dAT, B, 2.0, C5) - gemm(true, false, 1.0, sAT, B, 2.0, C6) - gemm(true, false, 2.0, dAT, B, 2.0, C7) - gemm(true, false, 2.0, sAT, B, 2.0, C8) + val dATT = dATman.transpose + val sATT = sATman.transpose + val BTT = BTman.transpose.asInstanceOf[DenseMatrix] + + assert(dATT.multiply(B) ~== expected absTol 1e-15) + assert(sATT.multiply(B) ~== expected absTol 1e-15) + assert(dATT.multiply(BTT) ~== expected absTol 1e-15) + assert(sATT.multiply(BTT) ~== expected absTol 1e-15) + + gemm(1.0, dATT, BTT, 2.0, C5) + gemm(1.0, sATT, BTT, 2.0, C6) + gemm(2.0, dATT, BTT, 2.0, C7) + gemm(2.0, sATT, BTT, 2.0, C8) + gemm(1.0, dA, BTT, 2.0, C9) + gemm(1.0, sA, BTT, 2.0, C10) + gemm(2.0, dA, BTT, 2.0, C11) + gemm(2.0, sA, BTT, 2.0, C12) assert(C5 ~== expected2 absTol 1e-15) assert(C6 ~== expected2 absTol 1e-15) assert(C7 ~== expected3 absTol 1e-15) assert(C8 ~== expected3 absTol 1e-15) + assert(C9 ~== expected2 absTol 1e-15) + assert(C10 ~== expected2 absTol 1e-15) + assert(C11 ~== expected3 absTol 1e-15) + assert(C12 ~== expected3 absTol 1e-15) } test("gemv") { @@ -233,17 +252,13 @@ class BLASSuite extends FunSuite { val x = new DenseVector(Array(1.0, 2.0, 3.0)) val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA multiply x ~== expected absTol 1e-15) - assert(sA multiply x ~== expected absTol 1e-15) + assert(dA.multiply(x) ~== expected absTol 1e-15) + assert(sA.multiply(x) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy - val y5 = y1.copy - val y6 = y1.copy - val y7 = y1.copy - val y8 = y1.copy val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -257,25 +272,18 @@ class BLASSuite extends FunSuite { assert(y4 ~== expected3 absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(true, 1.0, dA, x, 2.0, y1) + gemv(1.0, dA.transpose, x, 2.0, y1) } } - val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply x ~== expected absTol 1e-15) - assert(sAT transposeMultiply x ~== expected absTol 1e-15) - - gemv(true, 1.0, dAT, x, 2.0, y5) - gemv(true, 1.0, sAT, x, 2.0, y6) - gemv(true, 2.0, dAT, x, 2.0, y7) - gemv(true, 2.0, sAT, x, 2.0, y8) - assert(y5 ~== expected2 absTol 1e-15) - assert(y6 ~== expected2 absTol 1e-15) - assert(y7 ~== expected3 absTol 1e-15) - assert(y8 ~== expected3 absTol 1e-15) + val dATT = dAT.transpose + val sATT = sAT.transpose + + assert(dATT.multiply(x) ~== expected absTol 1e-15) + assert(sATT.multiply(x) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 73a6d3a27d868..2031032373971 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + // transposed matrix + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") { @@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(!matTransposed.values.eq(breeze.data), "has to copy data") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a35d0fe389fdd..b1ebfde0e5e57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -22,6 +22,9 @@ import java.util.Random import org.mockito.Mockito.when import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.mllib.util.TestingUtils._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite { assert(mat.numRows === m) assert(mat.numCols === n) assert(mat.values.eq(values), "should not copy data") - assert(mat.toArray.eq(values), "toArray should not copy data") } test("dense matrix construction with wrong dimension") { @@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite { assert(deMat1.toArray === deMat2.toArray) } + test("transpose") { + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val dAT = dA.transpose.asInstanceOf[DenseMatrix] + val sAT = sA.transpose.asInstanceOf[SparseMatrix] + val dATexpected = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sATexpected = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT.toBreeze === dATexpected.toBreeze) + assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dA(1, 0) === dAT(0, 1)) + assert(dA(2, 1) === dAT(1, 2)) + assert(sA(1, 0) === sAT(0, 1)) + assert(sA(2, 1) === sAT(1, 2)) + + assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") + assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") + + assert(dAT.toSparse().toBreeze === sATexpected.toBreeze) + assert(sAT.toDense().toBreeze === dATexpected.toBreeze) + } + + test("foreachActive") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val dn = new DenseMatrix(m, n, allValues) + + val dnMap = MutableMap[(Int, Int), Double]() + dn.foreachActive { (i, j, value) => + dnMap.put((i, j), value) + } + assert(dnMap.size === 6) + assert(dnMap(0, 0) === 1.0) + assert(dnMap(1, 0) === 2.0) + assert(dnMap(2, 0) === 0.0) + assert(dnMap(0, 1) === 0.0) + assert(dnMap(1, 1) === 4.0) + assert(dnMap(2, 1) === 5.0) + + val spMap = MutableMap[(Int, Int), Double]() + sp.foreachActive { (i, j, value) => + spMap.put((i, j), value) + } + assert(spMap.size === 4) + assert(spMap(0, 0) === 1.0) + assert(spMap(1, 0) === 2.0) + assert(spMap(1, 1) === 4.0) + assert(spMap(2, 1) === 5.0) + } + test("horzcat, vertcat, eye, speye") { val m = 3 val n = 2 @@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite { val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(0, 1, 1, 2) + // transposed versions + val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0) + val colPtrsT = Array(0, 1, 3, 4) + val rowIndicesT = Array(0, 0, 1, 1) val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) + val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values) + val deMat1T = new DenseMatrix(n, m, allValuesT) + + // should equal spMat1 & deMat1 respectively + val spMat1TT = spMat1T.transpose + val deMat1TT = deMat1T.transpose + val deMat2 = Matrices.eye(3) val spMat2 = Matrices.speye(3) val deMat3 = Matrices.eye(2) @@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = Matrices.horzcat(Array[Matrix]()) assert(deHorz1.numRows === 3) @@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite { assert(deHorz2.numCols === 0) assert(deHorz2.toArray.length === 0) - assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) - assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spHorz2 ~== spHorz3 absTol 1e-15) assert(spHorz(0, 0) === 1.0) assert(spHorz(2, 1) === 5.0) assert(spHorz(0, 2) === 1.0) @@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite { assert(deHorz1(2, 4) === 1.0) assert(deHorz1(1, 4) === 0.0) + // containing transposed matrices + val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2)) + val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2)) + val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2)) + val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2)) + + assert(deHorz1T ~== deHorz1 absTol 1e-15) + assert(spHorzT ~== spHorz absTol 1e-15) + assert(spHorz2T ~== spHorz2 absTol 1e-15) + assert(spHorz3T ~== spHorz3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.horzcat(Array(spMat1, spMat3)) } @@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite { assert(deVert2.numCols === 0) assert(deVert2.toArray.length === 0) - assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix) - assert(spVert2.toBreeze === spVert3.toBreeze) + assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spVert2 ~== spVert3 absTol 1e-15) assert(spVert(0, 0) === 1.0) assert(spVert(2, 1) === 5.0) assert(spVert(3, 0) === 1.0) @@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite { assert(deVert1(3, 1) === 0.0) assert(deVert1(4, 1) === 1.0) + // containing transposed matrices + val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3)) + val deVert1T = Matrices.vertcat(Array(deMat1TT, deMat3)) + val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3)) + val spVert3T = Matrices.vertcat(Array(deMat1TT, spMat3)) + + assert(deVert1T ~== deVert1 absTol 1e-15) + assert(spVertT ~== spVert absTol 1e-15) + assert(spVert2T ~== spVert2 absTol 1e-15) + assert(spVert3T ~== spVert3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.vertcat(Array(spMat1, spMat2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index f3b7bfda788fa..e9fc37e000526 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk prediciton + * @param bulkPredict flag to test bulk predicition * @param negativeWeights whether the generated data can contain negative values * @param numUserBlocks number of user blocks to partition users into * @param numProductBlocks number of product blocks to partition products into diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala new file mode 100644 index 0000000000000..92b498580af03 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + */ +class ImpuritySuite extends FunSuite with MLlibTestSparkContext { + test("Gini impurity does not support negative labels") { + val gini = new GiniAggregator(2) + intercept[IllegalArgumentException] { + gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } + + test("Entropy does not support negative labels") { + val entropy = new EntropyAggregator(2) + intercept[IllegalArgumentException] { + entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index f7f0f20c6c125..55e963977b54f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) EnsembleTestHelper.validateClassifier(model, arr, 1.0) } + + test("subsampling rate in RandomForest"){ + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], + useNodeIdCache = true) + + val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + strategy.subsamplingRate = 0.5 + val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + assert(rf1.toDebugString != rf2.toDebugString) + } + } diff --git a/pom.xml b/pom.xml index b993391b15042..05cb3797fc55b 100644 --- a/pom.xml +++ b/pom.xml @@ -117,7 +117,7 @@ 2.0.1 0.21.0 shaded-protobuf - 1.7.5 + 1.7.10 1.2.17 1.0.4 2.4.1 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 127973b658190..e750fed7448cd 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -52,6 +52,20 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrices.randn"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") ) ++ Seq( // SPARK-3325 ProblemFilters.exclude[MissingMethodProblem]( @@ -81,7 +95,20 @@ object MimaExcludes { ) ++ Seq( // SPARK-5166 Spark SQL API stabilization ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") ) ++ Seq( // SPARK-5270 ProblemFilters.exclude[MissingMethodProblem]( @@ -90,6 +117,10 @@ object MimaExcludes { // SPARK-5297 Java FileStream do not work with custom key/values ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") ) case v if v.startsWith("1.2") => diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 64f6a3ca6bf4c..568e21f3803bf 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -229,6 +229,14 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + def __getnewargs__(self): + # This method is called when attempting to pickle SparkContext, which is always an error: + raise Exception( + "It appears that you are attempting to reference SparkContext from a broadcast " + "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "not in code that it run on workers. For more information, see SPARK-5063." + ) + def __enter__(self): """ Enable 'with SparkContext(...) as sc: app(sc)' syntax. diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a975dc19cb78e..a0a028446d5fd 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -111,10 +111,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + # TODO(davies): move into sql + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4977400ac1c05..f4cfe4845dc20 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -141,6 +141,17 @@ def id(self): def __repr__(self): return self._jrdd.toString() + def __getnewargs__(self): + # This method is called when attempting to pickle an RDD, which is always an error: + raise Exception( + "It appears that you are attempting to broadcast an RDD or reference an RDD from an " + "action or transformation. RDD transformations and actions can only be invoked by the " + "driver, not inside of other transformations; for example, " + "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values " + "transformation and count action cannot be performed inside of the rdd1.map " + "transformation. For more information, see SPARK-5063." + ) + @property def context(self): """ diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1990323249cf6..7d7550c854b2f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,15 +20,19 @@ - L{SQLContext} Main entry point for SQL functionality. - - L{SchemaRDD} + - L{DataFrame} A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedDataFrame} + - L{Column} + Column is a DataFrame with a single column. - L{Row} A Row of data returned by a Spark SQL query. - L{HiveContext} Main entry point for accessing data stored in Apache Hive.. """ +import sys import itertools import decimal import datetime @@ -36,6 +40,9 @@ import warnings import json import re +import random +import os +from tempfile import NamedTemporaryFile from array import array from operator import itemgetter from itertools import imap @@ -43,6 +50,7 @@ from py4j.protocol import Py4JError from py4j.java_collections import ListConverter, MapConverter +from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ CloudPickleSerializer, UTF8Deserializer @@ -54,7 +62,8 @@ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", + "SchemaRDD"] class DataType(object): @@ -1171,7 +1180,7 @@ def Dict(d): class Row(tuple): - """ Row in SchemaRDD """ + """ Row in DataFrame """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) __slots__ = () @@ -1198,7 +1207,7 @@ class SQLContext(object): """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as tables, execute SQL over tables, cache tables, and read parquet files. """ @@ -1209,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None): :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... @@ -1225,12 +1234,12 @@ def __init__(self, sparkContext, sqlContext=None): >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerTempTable("allTypes") + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, ... x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -1309,23 +1318,23 @@ def inferSchema(self, rdd, samplingRatio=None): ... [Row(field1=1, field2="row1"), ... Row(field1=2, field2="row2"), ... Row(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') >>> NestedRow = Row("f1", "f2") >>> nestedRdd1 = sc.parallelize([ ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] >>> nestedRdd2 = sc.parallelize([ ... NestedRow([[1, 2], [2, 3]], [1, 2]), ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] >>> from collections import namedtuple @@ -1334,13 +1343,13 @@ def inferSchema(self, rdd, samplingRatio=None): ... [CustomRow(field1=1, field2="row1"), ... CustomRow(field1=2, field2="row2"), ... CustomRow(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") first = rdd.first() if not first: @@ -1384,10 +1393,10 @@ def applySchema(self, rdd, schema): >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import date, datetime @@ -1410,15 +1419,15 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) - >>> results = srdd.map( + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, ... x.time, x.map["a"], x.struct.b, x.list, x.null)) >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - >>> srdd.registerTempTable("table2") + >>> df.registerTempTable("table2") >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + @@ -1431,13 +1440,13 @@ def applySchema(self, rdd, schema): >>> abstract = "byte short float time map{} struct(b) list[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> srdd = sqlCtx.applySchema(rdd, typedSchema) - >>> srdd.collect() + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): raise TypeError("schema should be StructType") @@ -1457,8 +1466,8 @@ def applySchema(self, rdd, schema): rdd = rdd.map(converter) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -1466,34 +1475,34 @@ def registerRDDAsTable(self, rdd, tableName): Temporary tables exist only during the lifetime of this instance of SQLContext. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") """ - if (rdd.__class__ is SchemaRDD): - srdd = rdd._jschema_rdd.baseSchemaRDD() - self._ssql_ctx.registerRDDAsTable(srdd, tableName) + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) else: - raise ValueError("Can only register SchemaRDD as table") + raise ValueError("Can only register DataFrame as table") def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{SchemaRDD}. + """Loads a Parquet file, returning the result as a L{DataFrame}. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) - return SchemaRDD(jschema_rdd, self) + jdf = self._ssql_ctx.parquetFile(path) + return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a - L{SchemaRDD}. + L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1508,23 +1517,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) @@ -1536,23 +1545,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path, samplingRatio) + df = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. + """Loads an RDD storing one JSON object per string as a L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1560,23 +1569,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): Otherwise, it samples the dataset with ratio `samplingRatio` to determine the schema. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) @@ -1588,12 +1597,12 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] >>> sqlCtx.jsonRDD(sc.parallelize(['{}', @@ -1615,33 +1624,33 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{SchemaRDD} representing the result of the given query. + """Return a L{DataFrame} representing the result of the given query. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{SchemaRDD}. + """Returns the specified table as a L{DataFrame}. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.table("table1") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return DataFrame(self._ssql_ctx.table(tableName), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1707,7 +1716,7 @@ def _create_row(fields, values): class Row(tuple): """ - A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + A row in L{DataFrame}. The fields in it can be accessed like attributes. Row can be used to create a row object by using named arguments, the fields will be sorted by names. @@ -1799,111 +1808,119 @@ def inherit_doc(cls): return cls -@inherit_doc -class SchemaRDD(RDD): +class DataFrame(object): - """An RDD of L{Row} objects that has an associated schema. + """A collection of rows that have the same columns. - The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by Spark SQL. + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: - For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the - L{SchemaRDD} is not operated on directly, as it's underlying - implementation is an RDD composed of Java objects. Instead it is - converted to a PythonRDD in the JVM, on which Python operations can - be done. + people = sqlContext.parquetFile("...") - This class receives raw tuples from Java but assigns a class to it in - all its data-collection methods (mapPartitionsWithIndex, collect, take, - etc) so that PySpark sees them as Row objects with named fields. + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: [[DataFrame]], [[Column]]. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. + people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) """ - def __init__(self, jschema_rdd, sql_ctx): + def __init__(self, jdf, sql_ctx): + self._jdf = jdf self.sql_ctx = sql_ctx - self._sc = sql_ctx._sc - clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD" - self._jschema_rdd = jschema_rdd - self._id = None + self._sc = sql_ctx and sql_ctx._sc self.is_cached = False - self.is_checkpointed = False - self.ctx = self.sql_ctx._sc - # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property - def _jrdd(self): - """Lazy evaluation of PythonRDD object. + def rdd(self): + """Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row`s. """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() - Only done when a user calls methods defined by the - L{pyspark.rdd.RDD} super class (map, filter, etc.). - """ - if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() - return self._lazy_jrdd + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) - def id(self): - if self._id is None: - self._id = self._jrdd.id() - return self._id + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd def limit(self, num): """Limit the result count to the number specified. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.limit(2).collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.limit(2).collect() [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> srdd.limit(0).collect() + >>> df.limit(0).collect() [] """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) def toJSON(self, use_unicode=False): - """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( "SELECT * from table1") - >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( "SELECT * from table1") + >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' True - >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] True """ - rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a SchemaRDD using the L{SQLContext.parquetFile} method. + a DataFrame using the L{SQLContext.parquetFile} method. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd2.collect()) == sorted(srdd.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) True """ - self._jschema_rdd.saveAsParquetFile(path) + self._jdf.saveAsParquetFile(path) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this SchemaRDD. + that was used to create this DataFrame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerTempTable("test") - >>> srdd2 = sqlCtx.sql("select * from test") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.registerTempTable("test") + >>> df2 = sqlCtx.sql("select * from test") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - self._jschema_rdd.registerTempTable(name) + self._jdf.registerTempTable(name) def registerAsTable(self, name): """DEPRECATED: use registerTempTable() instead""" @@ -1911,62 +1928,61 @@ def registerAsTable(self, name): self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this SchemaRDD into the specified table. + """Inserts the contents of this DataFrame into the specified table. Optionally overwriting any existing data. """ - self._jschema_rdd.insertInto(tableName, overwrite) + self._jdf.insertInto(tableName, overwrite) def saveAsTable(self, tableName): - """Creates a new table with the contents of this SchemaRDD.""" - self._jschema_rdd.saveAsTable(tableName) + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) def schema(self): - """Returns the schema of this SchemaRDD (represented by + """Returns the schema of this DataFrame (represented by a L{StructType}).""" - return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) - - def schemaString(self): - """Returns the output schema in the tree format.""" - return self._jschema_rdd.schemaString() + return _parse_datatype_json_string(self._jdf.schema().json()) def printSchema(self): """Prints out the schema in the tree format.""" - print self.schemaString() + print (self._jdf.schema().treeString()) def count(self): """Return the number of elements in this RDD. Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the SchemaRDD, + leverages the query optimizer to compute the count on the DataFrame, which supports features such as filter pushdown. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.count() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.count() 3L - >>> srdd.count() == srdd.map(lambda x: x).count() + >>> df.count() == df.map(lambda x: x).count() True """ - return self._jschema_rdd.count() + return self._jdf.count() def collect(self): - """Return a list that contains all of the rows in this RDD. + """Return a list that contains all of the rows. Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of collect, this implementation - leverages the query optimizer to perform a collect on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect() [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] """ - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + return [cls(r) for r in rs] def take(self, num): """Take the first num rows of the RDD. @@ -1974,130 +1990,555 @@ def take(self, num): Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of take, this implementation - leverages the query optimizer to perform a collect on a SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.take(2) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.take(2) [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] """ return self.limit(num).collect() - # Convert each object in the RDD to a Row with the right class - # for this SchemaRDD, so that fields can be accessed as attributes. - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + def map(self, f): + """ Return a new RDD by applying a function to each Row, it's a + shorthand for df.rdd.map() """ - Return a new RDD by applying a function to each partition of this RDD, - while tracking the index of the original partition. + return self.rdd.map(f) - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(splitIndex, iterator): yield splitIndex - >>> rdd.mapPartitionsWithIndex(f).sum() - 6 + def mapPartitions(self, f, preservesPartitioning=False): """ - rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - - schema = self.schema() + Return a new RDD by applying a function to each partition. - def applySchema(_, it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) - return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) - # We override the default cache/persist/checkpoint behavior - # as we want to cache the underlying SchemaRDD object in the JVM, - # not the PythonRDD checkpointed by the super class def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - self._jschema_rdd.cache() + self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) - self._jschema_rdd.persist(javaStorageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) return self def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. + """ self.is_cached = False - self._jschema_rdd.unpersist(blocking) + self._jdf.unpersist(blocking) return self - def checkpoint(self): - self.is_checkpointed = True - self._jschema_rdd.checkpoint() + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) - def isCheckpointed(self): - return self._jschema_rdd.isCheckpointed() + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. + """ + rdd = self._jdf.repartition(numPartitions, None) + return DataFrame(rdd, self.sql_ctx) - def getCheckpointFile(self): - checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): - return checkpointFile.get() + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. - def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None) - return SchemaRDD(rdd, self.sql_ctx) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.sample(False, 0.5, 97).count() + 2L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) - def distinct(self, numPartitions=None): - if numPartitions is None: - rdd = self._jschema_rdd.distinct() + @property + def dtypes(self): + """Return all column names and their data types as a list. + """ + return [(f.name, str(f.dataType)) for f in self.schema().fields] + + @property + def columns(self): + """ Return all column names as a list. + """ + return [f.name for f in self.schema().fields] + + def show(self): + raise NotImplemented + + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another DataFrame, using the given join expression. + The following performs a full outer join between `df1` and `df2`:: + + df1.join(df2, df1.key == df2.key, "outer") + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, + `semijoin`. + """ + if joinType is None: + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + jdf = self._jdf.join(other._jdf, joinExprs) else: - rdd = self._jschema_rdd.distinct(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.join(other._jdf, joinExprs, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new [[DataFrame]] sorted by the specified column, + in ascending column. + + :param cols: The columns or expressions used for sorting + """ + if not cols: + raise ValueError("should sort by at least one column") + for i, c in enumerate(cols): + if isinstance(c, basestring): + cols[i] = Column(c) + jcols = [c._jc for c in cols] + jdf = self._jdf.join(*jcols) + return DataFrame(jdf, self.sql_ctx) + + sortBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def tail(self): + raise NotImplemented + + def __getitem__(self, item): + if isinstance(item, basestring): + return Column(self._jdf.apply(item)) + + # TODO projection + raise IndexError + + def __getattr__(self, name): + """ Return the column by given name """ + if isinstance(name, basestring): + return Column(self._jdf.apply(name)) + raise AttributeError + + def As(self, name): + """ Alias the current DataFrame """ + return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) + + def select(self, *cols): + """ Selecting a set of expressions.:: + + df.select() + df.select('colA', 'colB') + df.select(df.colA, df.colB + 1) - def intersection(self, other): - if (other.__class__ is SchemaRDD): - rdd = self._jschema_rdd.intersection(other._jschema_rdd) - return SchemaRDD(rdd, self.sql_ctx) + """ + if not cols: + cols = ["*"] + if isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only intersect with another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.select(self._jdf.toColumnArray(jcols)) + return DataFrame(jdf, self.sql_ctx) - def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + def filter(self, condition): + """ Filtering rows using the given condition:: - def subtract(self, other, numPartitions=None): - if (other.__class__ is SchemaRDD): - if numPartitions is None: - rdd = self._jschema_rdd.subtract(other._jschema_rdd) - else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) - return SchemaRDD(rdd, self.sql_ctx) + df.filter(df.age > 15) + df.where(df.age > 15) + + """ + return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the [[DataFrame]] using the specified columns, + so we can run aggregation on them. See :class:`GroupedDataFrame` + for all the available aggregate functions:: + + df.groupBy(df.department).avg() + df.groupBy("department", "gender").agg({ + "salary": "avg", + "age": "max", + }) + """ + if cols and isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only subtract another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols)) + return GroupedDataFrame(jdf, self.sql_ctx) - def sample(self, withReplacement, fraction, seed=None): + def agg(self, *exprs): + """ Aggregate on the entire [[DataFrame]] without groups + (shorthand for df.groupBy.agg()):: + + df.agg({"age": "max", "salary": "avg"}) """ - Return a sampled subset of this SchemaRDD. + return self.groupBy().agg(*exprs) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.sample(False, 0.5, 97).count() - 2L + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) - return SchemaRDD(rdd, self.sql_ctx) + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - def takeSample(self, withReplacement, num, seed=None): - """Return a fixed-size sampled subset of this SchemaRDD. + def intersect(self, other): + """ Return a new [[DataFrame]] containing rows only in + both this frame and another frame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.takeSample(False, 2, 97) - [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + This is equivalent to `INTERSECT` in SQL. """ - seed = seed if seed is not None else random.randint(0, sys.maxint) - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD() \ - .takeSampleToPython(withReplacement, num, long(seed)) \ - .iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def Except(self, other): + """ Return a new [[DataFrame]] containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ Return a new DataFrame by sampling a fraction of rows. """ + if seed is None: + jdf = self._jdf.sample(withReplacement, fraction) + else: + jdf = self._jdf.sample(withReplacement, fraction, seed) + return DataFrame(jdf, self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new [[DataFrame]] by adding a column. """ + return self.select('*', col.As(colName)) + + def removeColumn(self, colName): + raise NotImplemented + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedDataFrame(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to agregate methods. + """ + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns" + jdf = self._jdf.agg(*exprs) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. """ + + @dfapi + def mean(self): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" + + @dfapi + def max(self): + """Compute the max value for each numeric columns for + each group. """ + + @dfapi + def min(self): + """Compute the min value for each numeric column for + each group.""" + + @dfapi + def sum(self): + """Compute the sum for each numeric columns for each + group.""" + + +SCALA_METHOD_MAPPINGS = { + '=': '$eq', + '>': '$greater', + '<': '$less', + '+': '$plus', + '-': '$minus', + '*': '$times', + '/': '$div', + '!': '$bang', + '@': '$at', + '#': '$hash', + '%': '$percent', + '^': '$up', + '&': '$amp', + '~': '$tilde', + '?': '$qmark', + '|': '$bar', + '\\': '$bslash', + ':': '$colon', +} + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.Literal.apply(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Column(name) + + +def _scalaMethod(name): + """ Translate operators into methodName in Scala + + For example: + >>> _scalaMethod('+') + '$plus' + >>> _scalaMethod('>=') + '$greater$eq' + >>> _scalaMethod('cast') + 'cast' + """ + return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) + + +def _unary_op(name): + """ Create a method for given unary operator """ + def _(self): + return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) + return _ + + +def _bin_op(name): + """ Create a method for given binary operator """ + def _(self, other): + if isinstance(other, Column): + jc = other._jc + else: + jc = _create_column_from_literal(other) + return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) + return _ + + +def _reverse_op(name): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), + self._jdf, self.sql_ctx) + return _ + + +class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by: + {{{ + // 1. Select a column out of a DataFrame + df.colName + df["colName"] + + // 2. Create from an expression + df["colName"] + 1 + }}} + """ + + def __init__(self, jc, jdf=None, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jdf, sql_ctx) + + # arithmetic operators + __neg__ = _unary_op("unary_-") + __add__ = _bin_op("+") + __sub__ = _bin_op("-") + __mul__ = _bin_op("*") + __div__ = _bin_op("/") + __mod__ = _bin_op("%") + __radd__ = _bin_op("+") + __rsub__ = _reverse_op("-") + __rmul__ = _bin_op("*") + __rdiv__ = _reverse_op("/") + __rmod__ = _reverse_op("%") + __abs__ = _unary_op("abs") + abs = _unary_op("abs") + sqrt = _unary_op("sqrt") + + # logistic operators + __eq__ = _bin_op("===") + __ne__ = _bin_op("!==") + __lt__ = _bin_op("<") + __le__ = _bin_op("<=") + __ge__ = _bin_op(">=") + __gt__ = _bin_op(">") + # `and`, `or`, `not` cannot be overloaded in Python + And = _bin_op('&&') + Or = _bin_op('||') + Not = _unary_op('unary_!') + + # bitwise operators + __and__ = _bin_op("&") + __or__ = _bin_op("|") + __invert__ = _unary_op("unary_~") + __xor__ = _bin_op("^") + # __lshift__ = _bin_op("<<") + # __rshift__ = _bin_op(">>") + __rand__ = _bin_op("&") + __ror__ = _bin_op("|") + __rxor__ = _bin_op("^") + # __rlshift__ = _reverse_op("<<") + # __rrshift__ = _reverse_op(">>") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + # __getattr__ = _bin_op("getField") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + upper = _unary_op("upper") + lower = _unary_op("lower") + + def substr(self, startPos, pos): + if type(startPos) != type(pos): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + + jc = self._jc.substr(startPos, pos) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, pos._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc, self._jdf, self.sql_ctx) + + __getslice__ = substr + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + isNull = _unary_op("isNull") + isNotNull = _unary_op("isNotNull") + + # `as` is keyword + def As(self, alias): + return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) + + def cast(self, dataType): + if self.sql_ctx is None: + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + else: + ssql_ctx = self.sql_ctx._ssql_ctx + jdt = ssql_ctx.parseDataType(dataType.json()) + return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) + + +def _aggregate_func(name): + """ Creat a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + # FIXME: can not access dsl.min/max ... + jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol) + return Column(jc) + return staticmethod(_) + + +class Aggregator(object): + """ + A collections of builtin aggregators + """ + max = _aggregate_func("max") + min = _aggregate_func("min") + avg = mean = _aggregate_func("mean") + sum = _aggregate_func("sum") + first = _aggregate_func("first") + last = _aggregate_func("last") + count = _aggregate_func("count") def _test(): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7e..e8e207af462de 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -806,6 +806,9 @@ def tearDownClass(cls): def setUp(self): self.sqlCtx = SQLContext(self.sc) + self.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = self.sc.parallelize(self.testData) + self.df = self.sqlCtx.inferSchema(rdd) def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) @@ -821,7 +824,7 @@ def test_udf2(self): def test_udf_with_array_type(self): d = [Row(l=range(3), d={"key": range(5)})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.inferSchema(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() @@ -839,68 +842,51 @@ def test_broadcast_in_udf(self): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - srdd = self.sqlCtx.jsonRDD(rdd) - srdd.count() - srdd.collect() - srdd.schemaString() - srdd.schema() + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() # cache and checkpoint - self.assertFalse(srdd.is_cached) - srdd.persist() - srdd.unpersist() - srdd.cache() - self.assertTrue(srdd.is_cached) - self.assertFalse(srdd.isCheckpointed()) - self.assertEqual(None, srdd.getCheckpointFile()) - - srdd = srdd.coalesce(2, True) - srdd = srdd.repartition(3) - srdd = srdd.distinct() - srdd.intersection(srdd) - self.assertEqual(2, srdd.count()) - - srdd.registerTempTable("temp") - srdd = self.sqlCtx.sql("select foo from temp") - srdd.count() - srdd.collect() - - def test_distinct(self): - rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) - srdd = self.sqlCtx.jsonRDD(rdd) - self.assertEquals(srdd.getNumPartitions(), 10) - self.assertEquals(srdd.distinct().count(), 3) - result = srdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() def test_apply_schema_to_row(self): - srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) - self.assertEqual(srdd.collect(), srdd2.collect()) + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) - self.assertEqual(10, srdd3.count()) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - row = srdd.first() + df = self.sqlCtx.inferSchema(rdd) + row = df.head() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) self.assertEqual("2", row.d["key"].d) - l = srdd.map(lambda x: x.l).first() + l = df.map(lambda x: x.l).first() self.assertEqual(1, len(l)) self.assertEqual('s', l[0].b) - d = srdd.map(lambda x: x.d).first() + d = df.map(lambda x: x.d).first() self.assertEqual(1, len(d)) self.assertEqual(1.0, d["key"].c) - row = srdd.map(lambda x: x.d["key"]).first() + row = df.map(lambda x: x.d["key"]).first() self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) @@ -908,26 +894,26 @@ def test_infer_schema(self): d = [Row(l=[], d={}), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], srdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) - srdd.registerTempTable("test") + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) - srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(srdd.schema(), srdd2.schema()) - self.assertEqual({}, srdd2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) - srdd2.registerTempTable("test2") + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - k, v = srdd.first().m.items()[0] + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -935,9 +921,9 @@ def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - srdd.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").first() + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) @@ -945,12 +931,12 @@ def test_infer_schema_with_udt(self): from pyspark.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - schema = srdd.schema() + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) - srdd.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) def test_apply_schema_with_udt(self): @@ -959,21 +945,52 @@ def test_apply_schema_with_udt(self): rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - srdd = self.sqlCtx.applySchema(rdd, schema) - point = srdd.first().point + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_parquet_with_udt(self): from pyspark.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd0 = self.sqlCtx.inferSchema(rdd) + df0 = self.sqlCtx.inferSchema(rdd) output_dir = os.path.join(self.tempdir.name, "labeled_point") - srdd0.saveAsParquetFile(output_dir) - srdd1 = self.sqlCtx.parquetFile(output_dir) - point = srdd1.first().point + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_column_operators(self): + from pyspark.sql import Column, LongType + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbit)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + # TODO(davies): fix aggregators + from pyspark.sql import Aggregator as Agg + # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 22941edef2d46..4c5fb3f45bf49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -47,7 +47,7 @@ object NewRelationInstances extends Rule[LogicalPlan] { .toSet plan transform { - case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance + case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance() } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3035d934ff9f8..f388cd5972bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -77,6 +77,9 @@ abstract class Attribute extends NamedExpression { * For example the SQL expression "1 + 1 AS a" could be represented as follows: * Alias(Add(Literal(1), Literal(1), "a")() * + * Note that exprId and qualifiers are in a separate parameter list because + * we only pattern match on child and name. + * * @param child the computation being performed * @param name the name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 613f4bb09daf5..5dc0539caec24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,9 +17,24 @@ package org.apache.spark.sql.catalyst.plans +object JoinType { + def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + case "inner" => Inner + case "outer" | "full" | "fullouter" => FullOuter + case "leftouter" | "left" => LeftOuter + case "rightouter" | "right" => RightOuter + case "leftsemi" => LeftSemi + } +} + sealed abstract class JoinType + case object Inner extends JoinType + case object LeftOuter extends JoinType + case object RightOuter extends JoinType + case object FullOuter extends JoinType + case object LeftSemi extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala index 19769986ef58c..d90af45b375e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { - def apply(output: Attribute*) = - new LocalRelation(output) + def apply(output: Attribute*): LocalRelation = new LocalRelation(output) + + def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation( + StructType(output1 +: output).toAttributes + ) } case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index e715d9434a2ab..bc22f688338b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -80,7 +80,7 @@ private[sql] trait CacheManager { * the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: SchemaRDD, + query: DataFrame, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -100,7 +100,7 @@ private[sql] trait CacheManager { } /** Removes the data for the given SchemaRDD from the cache */ - private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock { + private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -110,7 +110,7 @@ private[sql] trait CacheManager { /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */ private[sql] def tryUncacheQuery( - query: SchemaRDD, + query: DataFrame, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,7 +123,7 @@ private[sql] trait CacheManager { } /** Optionally returns cached data for the given SchemaRDD */ - private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala new file mode 100644 index 0000000000000..7fc8347428df4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -0,0 +1,528 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.types._ + + +object Column { + def unapply(col: Column): Option[Expression] = Some(col.expr) + + def apply(colName: String): Column = new Column(colName) +} + + +/** + * A column in a [[DataFrame]]. + * + * `Column` instances can be created by: + * {{{ + * // 1. Select a column out of a DataFrame + * df("colName") + * + * // 2. Create a literal expression + * Literal(1) + * + * // 3. Create new columns from + * }}} + * + */ +// TODO: Improve documentation. +class Column( + sqlContext: Option[SQLContext], + plan: Option[LogicalPlan], + val expr: Expression) + extends DataFrame(sqlContext, plan) with ExpressionApi { + + /** Turn a Catalyst expression into a `Column`. */ + protected[sql] def this(expr: Expression) = this(None, None, expr) + + /** + * Create a new `Column` expression based on a column or attribute name. + * The resolution of this is the same as SQL. For example: + * + * - "colName" becomes an expression selecting the column named "colName". + * - "*" becomes an expression selecting all columns. + * - "df.*" becomes an expression selecting all columns in data frame "df". + */ + def this(name: String) = this(name match { + case "*" => Star(None) + case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2))) + case _ => UnresolvedAttribute(name) + }) + + override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined + + /** + * An implicit conversion function internal to this class. This function creates a new Column + * based on an expression. If the expression itself is not named, it aliases the expression + * by calling it "col". + */ + private[this] implicit def toColumn(expr: Expression): Column = { + val projectedPlan = plan.map { p => + Project(Seq(expr match { + case named: NamedExpression => named + case unnamed: Expression => Alias(unnamed, "col")() + }), p) + } + new Column(sqlContext, projectedPlan, expr) + } + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * df.select( -df("amount") ) + * }}} + */ + override def unary_- : Column = UnaryMinus(expr) + + /** + * Bitwise NOT. + * {{{ + * // Select the flags column and negate every bit. + * df.select( ~df("flags") ) + * }}} + */ + override def unary_~ : Column = BitwiseNot(expr) + + /** + * Invert a boolean expression, i.e. NOT. + * {{ + * // Select rows that are not active (isActive === false) + * df.select( !df("isActive") ) + * }} + */ + override def unary_! : Column = Not(expr) + + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA".equalTo(df("colB")) ) + * }}} + */ + override def === (other: Column): Column = EqualTo(expr, other.expr) + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA".equalTo("Zaharia") ) + * }}} + */ + override def === (literal: Any): Column = this === Literal.anyToLiteral(literal) + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA".equalTo(df("colB")) ) + * }}} + */ + override def equalTo(other: Column): Column = this === other + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA".equalTo("Zaharia") ) + * }}} + */ + override def equalTo(literal: Any): Column = this === literal + + /** + * Inequality test with an expression. + * {{{ + * // The following two both select rows in which colA does not equal colB. + * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * }}} + */ + override def !== (other: Column): Column = Not(EqualTo(expr, other.expr)) + + /** + * Inequality test with a literal value. + * {{{ + * // The following two both select rows in which colA does not equal equal 15. + * df.select( df("colA") !== 15 ) + * df.select( !(df("colA") === 15) ) + * }}} + */ + override def !== (literal: Any): Column = this !== Literal.anyToLiteral(literal) + + /** + * Greater than an expression. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > Literal(21) ) + * }}} + */ + override def > (other: Column): Column = GreaterThan(expr, other.expr) + + /** + * Greater than a literal value. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > 21 ) + * }}} + */ + override def > (literal: Any): Column = this > Literal.anyToLiteral(literal) + + /** + * Less than an expression. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < Literal(21) ) + * }}} + */ + override def < (other: Column): Column = LessThan(expr, other.expr) + + /** + * Less than a literal value. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < 21 ) + * }}} + */ + override def < (literal: Any): Column = this < Literal.anyToLiteral(literal) + + /** + * Less than or equal to an expression. + * {{{ + * // The following selects people age 21 or younger than 21. + * people.select( people("age") <= Literal(21) ) + * }}} + */ + override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr) + + /** + * Less than or equal to a literal value. + * {{{ + * // The following selects people age 21 or younger than 21. + * people.select( people("age") <= 21 ) + * }}} + */ + override def <= (literal: Any): Column = this <= Literal.anyToLiteral(literal) + + /** + * Greater than or equal to an expression. + * {{{ + * // The following selects people age 21 or older than 21. + * people.select( people("age") >= Literal(21) ) + * }}} + */ + override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr) + + /** + * Greater than or equal to a literal value. + * {{{ + * // The following selects people age 21 or older than 21. + * people.select( people("age") >= 21 ) + * }}} + */ + override def >= (literal: Any): Column = this >= Literal.anyToLiteral(literal) + + /** + * Equality test with an expression that is safe for null values. + */ + override def <=> (other: Column): Column = EqualNullSafe(expr, other.expr) + + /** + * Equality test with a literal value that is safe for null values. + */ + override def <=> (literal: Any): Column = this <=> Literal.anyToLiteral(literal) + + /** + * True if the current expression is null. + */ + override def isNull: Column = IsNull(expr) + + /** + * True if the current expression is NOT null. + */ + override def isNotNull: Column = IsNotNull(expr) + + /** + * Boolean OR with an expression. + * {{{ + * // The following selects people that are in school or employed. + * people.select( people("inSchool") || people("isEmployed") ) + * }}} + */ + override def || (other: Column): Column = Or(expr, other.expr) + + /** + * Boolean OR with a literal value. + * {{{ + * // The following selects everything. + * people.select( people("inSchool") || true ) + * }}} + */ + override def || (literal: Boolean): Column = this || Literal.anyToLiteral(literal) + + /** + * Boolean AND with an expression. + * {{{ + * // The following selects people that are in school and employed at the same time. + * people.select( people("inSchool") && people("isEmployed") ) + * }}} + */ + override def && (other: Column): Column = And(expr, other.expr) + + /** + * Boolean AND with a literal value. + * {{{ + * // The following selects people that are in school. + * people.select( people("inSchool") && true ) + * }}} + */ + override def && (literal: Boolean): Column = this && Literal.anyToLiteral(literal) + + /** + * Bitwise AND with an expression. + */ + override def & (other: Column): Column = BitwiseAnd(expr, other.expr) + + /** + * Bitwise AND with a literal value. + */ + override def & (literal: Any): Column = this & Literal.anyToLiteral(literal) + + /** + * Bitwise OR with an expression. + */ + override def | (other: Column): Column = BitwiseOr(expr, other.expr) + + /** + * Bitwise OR with a literal value. + */ + override def | (literal: Any): Column = this | Literal.anyToLiteral(literal) + + /** + * Bitwise XOR with an expression. + */ + override def ^ (other: Column): Column = BitwiseXor(expr, other.expr) + + /** + * Bitwise XOR with a literal value. + */ + override def ^ (literal: Any): Column = this ^ Literal.anyToLiteral(literal) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * }}} + */ + override def + (other: Column): Column = Add(expr, other.expr) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and 10. + * people.select( people("height") + 10 ) + * }}} + */ + override def + (literal: Any): Column = this + Literal.anyToLiteral(literal) + + /** + * Subtraction. Substract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people("height") - people("weight") ) + * }}} + */ + override def - (other: Column): Column = Subtract(expr, other.expr) + + /** + * Subtraction. Substract a literal value from this expression. + * {{{ + * // The following selects a person's height and substract it by 10. + * people.select( people("height") - 10 ) + * }}} + */ + override def - (literal: Any): Column = this - Literal.anyToLiteral(literal) + + /** + * Multiply this expression and another expression. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people("height") * people("weight") ) + * }}} + */ + override def * (other: Column): Column = Multiply(expr, other.expr) + + /** + * Multiply this expression and a literal value. + * {{{ + * // The following multiplies a person's height by 10. + * people.select( people("height") * 10 ) + * }}} + */ + override def * (literal: Any): Column = this * Literal.anyToLiteral(literal) + + /** + * Divide this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people("height") / people("weight") ) + * }}} + */ + override def / (other: Column): Column = Divide(expr, other.expr) + + /** + * Divide this expression by a literal value. + * {{{ + * // The following divides a person's height by 10. + * people.select( people("height") / 10 ) + * }}} + */ + override def / (literal: Any): Column = this / Literal.anyToLiteral(literal) + + /** + * Modulo (a.k.a. remainder) expression. + */ + override def % (other: Column): Column = Remainder(expr, other.expr) + + /** + * Modulo (a.k.a. remainder) expression. + */ + override def % (literal: Any): Column = this % Literal.anyToLiteral(literal) + + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + */ + @scala.annotation.varargs + override def in(list: Column*): Column = In(expr, list.map(_.expr)) + + override def like(other: Column): Column = Like(expr, other.expr) + + override def like(literal: String): Column = this.like(Literal.anyToLiteral(literal)) + + override def rlike(other: Column): Column = RLike(expr, other.expr) + + override def rlike(literal: String): Column = this.rlike(Literal.anyToLiteral(literal)) + + + override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal)) + + override def getItem(ordinal: Column): Column = GetItem(expr, ordinal.expr) + + override def getField(fieldName: String): Column = GetField(expr, fieldName) + + + override def substr(startPos: Column, len: Column): Column = + Substring(expr, startPos.expr, len.expr) + + override def substr(startPos: Int, len: Int): Column = + this.substr(Literal.anyToLiteral(startPos), Literal.anyToLiteral(len)) + + override def contains(other: Column): Column = Contains(expr, other.expr) + + override def contains(literal: Any): Column = this.contains(Literal.anyToLiteral(literal)) + + + override def startsWith(other: Column): Column = StartsWith(expr, other.expr) + + override def startsWith(literal: String): Column = this.startsWith(Literal.anyToLiteral(literal)) + + override def endsWith(other: Column): Column = EndsWith(expr, other.expr) + + override def endsWith(literal: String): Column = this.endsWith(Literal.anyToLiteral(literal)) + + override def as(alias: String): Column = Alias(expr, alias)() + + override def cast(to: DataType): Column = Cast(expr, to) + + override def desc: Column = SortOrder(expr, Descending) + + override def asc: Column = SortOrder(expr, Ascending) +} + + +class ColumnName(name: String) extends Column(name) { + + /** Creates a new AttributeReference of type boolean */ + def boolean: StructField = StructField(name, BooleanType) + + /** Creates a new AttributeReference of type byte */ + def byte: StructField = StructField(name, ByteType) + + /** Creates a new AttributeReference of type short */ + def short: StructField = StructField(name, ShortType) + + /** Creates a new AttributeReference of type int */ + def int: StructField = StructField(name, IntegerType) + + /** Creates a new AttributeReference of type long */ + def long: StructField = StructField(name, LongType) + + /** Creates a new AttributeReference of type float */ + def float: StructField = StructField(name, FloatType) + + /** Creates a new AttributeReference of type double */ + def double: StructField = StructField(name, DoubleType) + + /** Creates a new AttributeReference of type string */ + def string: StructField = StructField(name, StringType) + + /** Creates a new AttributeReference of type date */ + def date: StructField = StructField(name, DateType) + + /** Creates a new AttributeReference of type decimal */ + def decimal: StructField = StructField(name, DecimalType.Unlimited) + + /** Creates a new AttributeReference of type decimal */ + def decimal(precision: Int, scale: Int): StructField = + StructField(name, DecimalType(precision, scale)) + + /** Creates a new AttributeReference of type timestamp */ + def timestamp: StructField = StructField(name, TimestampType) + + /** Creates a new AttributeReference of type binary */ + def binary: StructField = StructField(name, BinaryType) + + /** Creates a new AttributeReference of type array */ + def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType)) + + /** Creates a new AttributeReference of type map */ + def map(keyType: DataType, valueType: DataType): StructField = + map(MapType(keyType, valueType)) + + def map(mapType: MapType): StructField = StructField(name, mapType) + + /** Creates a new AttributeReference of type struct */ + def struct(fields: StructField*): StructField = struct(StructType(fields)) + + def struct(structType: StructType): StructField = StructField(name, structType) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala new file mode 100644 index 0000000000000..d0bb3640f8c1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -0,0 +1,596 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + +import java.util.{ArrayList, List => JList} + +import com.fasterxml.jackson.core.JsonFactory +import net.razorvine.pickle.Pickler + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil +import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.util.Utils + + +/** + * A collection of rows that have the same columns. + * + * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using + * various functions in [[SQLContext]]. + * {{{ + * val people = sqlContext.parquetFile("...") + * }}} + * + * Once created, it can be manipulated using the various domain-specific-language (DSL) functions + * defined in: [[DataFrame]] (this class), [[Column]], and [[dsl]] for Scala DSL. + * + * To select a column from the data frame, use the apply method: + * {{{ + * val ageCol = people("age") // in Scala + * Column ageCol = people.apply("age") // in Java + * }}} + * + * Note that the [[Column]] type can also be manipulated through its various functions. + * {{ + * // The following creates a new column that increases everybody's age by 10. + * people("age") + 10 // in Scala + * }} + * + * A more concrete example: + * {{{ + * // To create DataFrame using SQLContext + * val people = sqlContext.parquetFile("...") + * val department = sqlContext.parquetFile("...") + * + * people.filter("age" > 30) + * .join(department, people("deptId") === department("id")) + * .groupBy(department("name"), "gender") + * .agg(avg(people("salary")), max(people("age"))) + * }}} + */ +// TODO: Improve documentation. +class DataFrame protected[sql]( + val sqlContext: SQLContext, + private val baseLogicalPlan: LogicalPlan, + operatorsEnabled: Boolean) + extends DataFrameSpecificApi with RDDApi[Row] { + + protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) = + this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined) + + protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true) + + @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + baseLogicalPlan + } + + /** + * An implicit conversion function internal to this class for us to avoid doing + * "new DataFrame(...)" everywhere. + */ + private[this] implicit def toDataFrame(logicalPlan: LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan, true) + } + + /** Return the list of numeric columns, useful for doing aggregation. */ + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + logicalPlan.resolve(n.name, sqlContext.analyzer.resolver).get + } + } + + /** Resolve a column name into a Catalyst [[NamedExpression]]. */ + protected[sql] def resolve(colName: String): NamedExpression = { + logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse( + throw new RuntimeException(s"""Cannot resolve column name "$colName"""")) + } + + /** Left here for compatibility reasons. */ + @deprecated("1.3.0", "use toDataFrame") + def toSchemaRDD: DataFrame = this + + /** + * Return the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala. + */ + def toDF: DataFrame = this + + /** Return the schema of this [[DataFrame]]. */ + override def schema: StructType = queryExecution.analyzed.schema + + /** Return all column names and their data types as an array. */ + override def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } + + /** Return all column names as an array. */ + override def columns: Array[String] = schema.fields.map(_.name) + + /** Print the schema to the console in a nice tree format. */ + override def printSchema(): Unit = println(schema.treeString) + + /** + * Cartesian join with another [[DataFrame]]. + * + * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * + * @param right Right side of the join operation. + */ + override def join(right: DataFrame): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } + + /** + * Inner join with another [[DataFrame]], using the given join expression. + * + * {{{ + * // The following two are equivalent: + * df1.join(df2, $"df1Key" === $"df2Key") + * df1.join(df2).where($"df1Key" === $"df2Key") + * }}} + */ + override def join(right: DataFrame, joinExprs: Column): DataFrame = { + Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr)) + } + + /** + * Join with another [[DataFrame]], usin g the given join expression. The following performs + * a full outer join between `df1` and `df2`. + * + * {{{ + * df1.join(df2, "outer", $"df1Key" === $"df2Key") + * }}} + * + * @param right Right side of the join. + * @param joinExprs Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + */ + override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + } + + /** + * Return a new [[DataFrame]] sorted by the specified column, in ascending column. + * {{{ + * // The following 3 are equivalent + * df.sort("sortcol") + * df.sort($"sortcol") + * df.sort($"sortcol".asc) + * }}} + */ + override def sort(colName: String): DataFrame = { + Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan) + } + + /** + * Return a new [[DataFrame]] sorted by the given expressions. For example: + * {{{ + * df.sort($"col1", $"col2".desc) + * }}} + */ + @scala.annotation.varargs + override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = { + val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = true, logicalPlan) + } + + /** + * Return a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + */ + @scala.annotation.varargs + override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = { + sort(sortExpr, sortExprs :_*) + } + + /** + * Selecting a single column and return it as a [[Column]]. + */ + override def apply(colName: String): Column = { + val expr = resolve(colName) + new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr) + } + + /** + * Selecting a set of expressions, wrapped in a Product. + * {{{ + * // The following two are equivalent: + * df.apply(($"colA", $"colB" + 1)) + * df.select($"colA", $"colB" + 1) + * }}} + */ + override def apply(projection: Product): DataFrame = { + require(projection.productArity >= 1) + select(projection.productIterator.map { + case c: Column => c + case o: Any => new Column(Some(sqlContext), None, LiteralExpr(o)) + }.toSeq :_*) + } + + /** + * Alias the current [[DataFrame]]. + */ + override def as(name: String): DataFrame = Subquery(name, logicalPlan) + + /** + * Selecting a set of expressions. + * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + */ + @scala.annotation.varargs + override def select(cols: Column*): DataFrame = { + val exprs = cols.zipWithIndex.map { + case (Column(expr: NamedExpression), _) => + expr + case (Column(expr: Expression), _) => + Alias(expr, expr.toString)() + } + Project(exprs.toSeq, logicalPlan) + } + + /** + * Selecting a set of columns. This is a variant of `select` that can only select + * existing columns using column names (i.e. cannot construct expressions). + * + * {{{ + * // The following two are equivalent: + * df.select("colA", "colB") + * df.select($"colA", $"colB") + * }}} + */ + @scala.annotation.varargs + override def select(col: String, cols: String*): DataFrame = { + select((col +: cols).map(new Column(_)) :_*) + } + + /** + * Filtering rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def filter(condition: Column): DataFrame = { + Filter(condition.expr, logicalPlan) + } + + /** + * Filtering rows using the given condition. This is an alias for `filter`. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def where(condition: Column): DataFrame = filter(condition) + + /** + * Filtering rows using the given condition. This is a shorthand meant for Scala. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def apply(condition: Column): DataFrame = filter(condition) + + /** + * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(cols: Column*): GroupedDataFrame = { + new GroupedDataFrame(this, cols.map(_.expr)) + } + + /** + * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy("department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): GroupedDataFrame = { + val colNames: Seq[String] = col1 +: cols + new GroupedDataFrame(this, colNames.map(colName => resolve(colName))) + } + + /** + * Aggregate on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }} + */ + override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * Aggregate on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(max($"age"), avg($"salary")) + * df.groupBy().agg(max($"age"), avg($"salary")) + * }} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + + /** + * Return a new [[DataFrame]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. + */ + override def limit(n: Int): DataFrame = Limit(LiteralExpr(n), logicalPlan) + + /** + * Return a new [[DataFrame]] containing union of rows in this frame and another frame. + * This is equivalent to `UNION ALL` in SQL. + */ + override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + + /** + * Return a new [[DataFrame]] containing rows only in both this frame and another frame. + * This is equivalent to `INTERSECT` in SQL. + */ + override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + + /** + * Return a new [[DataFrame]] containing rows in this frame but not in another frame. + * This is equivalent to `EXCEPT` in SQL. + */ + override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + + /** + * Return a new [[DataFrame]] by sampling a fraction of rows. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * @param seed Seed for sampling. + */ + override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + Sample(fraction, withReplacement, seed, logicalPlan) + } + + /** + * Return a new [[DataFrame]] by sampling a fraction of rows, using a random seed. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + */ + override def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Return a new [[DataFrame]] by adding a column. + */ + override def addColumn(colName: String, col: Column): DataFrame = { + select(Column("*"), col.as(colName)) + } + + /** + * Return the first `n` rows. + */ + override def head(n: Int): Array[Row] = limit(n).collect() + + /** + * Return the first row. + */ + override def head(): Row = head(1).head + + /** + * Return the first row. Alias for head(). + */ + override def first(): Row = head() + + override def map[R: ClassTag](f: Row => R): RDD[R] = { + rdd.map(f) + } + + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + + /** + * Return the first `n` rows in the [[DataFrame]]. + */ + override def take(n: Int): Array[Row] = head(n) + + /** + * Return an array that contains all of [[Row]]s in this [[DataFrame]]. + */ + override def collect(): Array[Row] = rdd.collect() + + /** + * Return a Java list that contains all of [[Row]]s in this [[DataFrame]]. + */ + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + + /** + * Return the number of rows in the [[DataFrame]]. + */ + override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) + + /** + * Return a new [[DataFrame]] that has exactly `numPartitions` partitions. + */ + override def repartition(numPartitions: Int): DataFrame = { + sqlContext.applySchema(rdd.repartition(numPartitions), schema) + } + + override def persist(): this.type = { + sqlContext.cacheQuery(this) + this + } + + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheQuery(this, None, newLevel) + this + } + + override def unpersist(blocking: Boolean): this.type = { + sqlContext.tryUncacheQuery(this, blocking) + this + } + + ///////////////////////////////////////////////////////////////////////////// + // I/O + ///////////////////////////////////////////////////////////////////////////// + + /** + * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s. + */ + override def rdd: RDD[Row] = { + val schema = this.schema + queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + } + + /** + * Registers this RDD as a temporary table using the given name. The lifetime of this temporary + * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * + * @group schema + */ + override def registerTempTable(tableName: String): Unit = { + sqlContext.registerRDDAsTable(this, tableName) + } + + /** + * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. + * Files that are written out using this method can be read back in as a [[DataFrame]] + * using the `parquetFile` function in [[SQLContext]]. + */ + override def saveAsParquetFile(path: String): Unit = { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame. This will fail if the table already + * exists. + * + * Note that this currently only works with DataFrame that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + override def saveAsTable(tableName: String): Unit = { + sqlContext.executePlan( + CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd + } + + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. + */ + @Experimental + override def insertInto(tableName: String, overwrite: Boolean): Unit = { + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite)).toRdd + } + + /** + * Return the content of the [[DataFrame]] as a RDD of JSON strings. + */ + override def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val jsonFactory = new JsonFactory() + iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) + } + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + /** + * A helpful function for Py4j, convert a list of Column to an array + */ + protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = { + cols.toList.toArray + } + + /** + * Converts a JavaRDD to a PythonRDD. + */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala new file mode 100644 index 0000000000000..1f1e9bd9899f6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate + + +/** + * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. + */ +class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) + extends GroupedDataFrameApi { + + private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { + val namedGroupingExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + new DataFrame(df.sqlContext, + Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) + } + + private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = { + df.numericColumns.map { c => + val a = f(c) + Alias(a, a.toString)() + } + } + + private[this] def strToExpr(expr: String): (Expression => Expression) = { + expr.toLowerCase match { + case "avg" | "average" | "mean" => Average + case "max" => Max + case "min" => Min + case "sum" => Sum + case "count" | "size" => Count + } + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max" + * "sum" -> "expense" + * )) + * }}} + */ + override def agg(exprs: Map[String, String]): DataFrame = { + exprs.map { case (colName, expr) => + val a = strToExpr(expr)(df(colName).expr) + Alias(a, a.toString)() + }.toSeq + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max" + * "sum" -> "expense" + * )) + * }}} + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. + * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]]. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import org.apache.spark.sql.dsl._ + * df.groupBy("department").agg(max($"age"), sum($"expense")) + * }}} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = { + val aggExprs = (expr +: exprs).map(_.expr).map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + + new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) + } + + /** Count the number of rows for each group. */ + override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) + + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + */ + override def mean(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the max value for each numeric columns for each group. + */ + override def max(): DataFrame = aggregateNumericColumns(Max) + + /** + * Compute the mean value for each numeric columns for each group. + */ + override def avg(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the min value for each numeric column for each group. + */ + override def min(): DataFrame = aggregateNumericColumns(Min) + + /** + * Compute the sum for each numeric columns for each group. + */ + override def sum(): DataFrame = aggregateNumericColumns(Sum) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala new file mode 100644 index 0000000000000..08cd4d0f3f009 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.types._ + +object Literal { + + /** Return a new boolean literal. */ + def apply(literal: Boolean): Column = new Column(LiteralExpr(literal)) + + /** Return a new byte literal. */ + def apply(literal: Byte): Column = new Column(LiteralExpr(literal)) + + /** Return a new short literal. */ + def apply(literal: Short): Column = new Column(LiteralExpr(literal)) + + /** Return a new int literal. */ + def apply(literal: Int): Column = new Column(LiteralExpr(literal)) + + /** Return a new long literal. */ + def apply(literal: Long): Column = new Column(LiteralExpr(literal)) + + /** Return a new float literal. */ + def apply(literal: Float): Column = new Column(LiteralExpr(literal)) + + /** Return a new double literal. */ + def apply(literal: Double): Column = new Column(LiteralExpr(literal)) + + /** Return a new string literal. */ + def apply(literal: String): Column = new Column(LiteralExpr(literal)) + + /** Return a new decimal literal. */ + def apply(literal: BigDecimal): Column = new Column(LiteralExpr(literal)) + + /** Return a new decimal literal. */ + def apply(literal: java.math.BigDecimal): Column = new Column(LiteralExpr(literal)) + + /** Return a new timestamp literal. */ + def apply(literal: java.sql.Timestamp): Column = new Column(LiteralExpr(literal)) + + /** Return a new date literal. */ + def apply(literal: java.sql.Date): Column = new Column(LiteralExpr(literal)) + + /** Return a new binary (byte array) literal. */ + def apply(literal: Array[Byte]): Column = new Column(LiteralExpr(literal)) + + /** Return a new null literal. */ + def apply(literal: Null): Column = new Column(LiteralExpr(null)) + + /** + * Return a Column expression representing the literal value. Throws an exception if the + * data type is not supported by SparkSQL. + */ + protected[sql] def anyToLiteral(literal: Any): Column = { + // If the literal is a symbol, convert it into a Column. + if (literal.isInstanceOf[Symbol]) { + return dsl.symbolToColumn(literal.asInstanceOf[Symbol]) + } + + val literalExpr = literal match { + case v: Int => LiteralExpr(v, IntegerType) + case v: Long => LiteralExpr(v, LongType) + case v: Double => LiteralExpr(v, DoubleType) + case v: Float => LiteralExpr(v, FloatType) + case v: Byte => LiteralExpr(v, ByteType) + case v: Short => LiteralExpr(v, ShortType) + case v: String => LiteralExpr(v, StringType) + case v: Boolean => LiteralExpr(v, BooleanType) + case v: BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited) + case v: java.math.BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited) + case v: Decimal => LiteralExpr(v, DecimalType.Unlimited) + case v: java.sql.Timestamp => LiteralExpr(v, TimestampType) + case v: java.sql.Date => LiteralExpr(v, DateType) + case v: Array[Byte] => LiteralExpr(v, BinaryType) + case null => LiteralExpr(null, NullType) + case _ => + throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal) + } + new Column(literalExpr) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0a22968cc7807..5030e689c36ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -30,7 +30,6 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -43,7 +42,7 @@ import org.apache.spark.util.Utils /** * :: AlphaComponent :: - * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]] + * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]] * objects and the execution of SQL queries. * * @groupname userf Spark SQL Functions @@ -53,7 +52,6 @@ import org.apache.spark.util.Utils class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging with CacheManager - with ExpressionConversions with Serializable { self => @@ -111,8 +109,8 @@ class SQLContext(@transient val sparkContext: SparkContext) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + + protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) sparkContext.getConf.getAll.foreach { case (key, value) if key.startsWith("spark.sql") => setConf(key, value) @@ -124,24 +122,24 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = { + implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { SparkPlan.currentContext.set(self) val attributeSeq = ScalaReflection.attributesFor[A] val schema = StructType.fromAttributes(attributeSeq) val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) } /** - * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]]. + * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. */ - def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { - new SchemaRDD(this, LogicalRelation(baseRelation)) + def baseRelationToSchemaRDD(baseRelation: BaseRelation): DataFrame = { + new DataFrame(this, LogicalRelation(baseRelation)) } /** * :: DeveloperApi :: - * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. * Example: @@ -170,11 +168,11 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ @DeveloperApi - def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) - new SchemaRDD(this, logicalPlan) + new DataFrame(this, logicalPlan) } /** @@ -183,7 +181,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ - def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -201,7 +199,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ) : Row } } - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -210,35 +208,35 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { applySchema(rdd.rdd, beanClass) } /** - * Loads a Parquet file, returning the result as a [[SchemaRDD]]. + * Loads a Parquet file, returning the result as a [[DataFrame]]. * * @group userf */ - def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + def parquetFile(path: String): DataFrame = + new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** - * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) /** * :: Experimental :: * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. * * @group userf */ @Experimental - def jsonFile(path: String, schema: StructType): SchemaRDD = { + def jsonFile(path: String, schema: StructType): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, schema) } @@ -247,29 +245,29 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { + def jsonFile(path: String, samplingRatio: Double): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, samplingRatio) } /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[SchemaRDD]]. + * [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) /** * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. * * @group userf */ @Experimental - def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = Option(schema).getOrElse( @@ -283,7 +281,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = JsonRDD.nullTypeToStringType( @@ -298,8 +296,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), rdd.queryExecution.logical) + def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), rdd.logicalPlan) } /** @@ -321,17 +319,17 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def sql(sqlText: String): SchemaRDD = { + def sql(sqlText: String): DataFrame = { if (conf.dialect == "sql") { - new SchemaRDD(this, parseSql(sqlText)) + new DataFrame(this, parseSql(sqlText)) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}") } } /** Returns the specified table as a SchemaRDD */ - def table(tableName: String): SchemaRDD = - new SchemaRDD(this, catalog.lookupRelation(Seq(tableName))) + def table(tableName: String): DataFrame = + new DataFrame(this, catalog.lookupRelation(Seq(tableName))) /** * A collection of methods that are considered experimental, but can be used to hook into @@ -454,15 +452,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * access to the intermediate phases of query execution for developers. */ @DeveloperApi - protected abstract class QueryExecution { - def logical: LogicalPlan + protected class QueryExecution(val logical: LogicalPlan) { - lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) - lazy val withCachedData = useCachedData(analyzed) - lazy val optimizedPlan = optimizer(withCachedData) + lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical)) + lazy val withCachedData: LogicalPlan = useCachedData(analyzed) + lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... - lazy val sparkPlan = { + lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(self) planner(optimizedPlan).next() } @@ -512,7 +509,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schemaString: String): SchemaRDD = { + schemaString: String): DataFrame = { val schema = parseDataType(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) } @@ -522,7 +519,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schema: StructType): SchemaRDD = { + schema: StructType): DataFrame = { def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true @@ -549,7 +546,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala deleted file mode 100644 index d1e21dffeb8c5..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ /dev/null @@ -1,511 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.util.{List => JList} - -import scala.collection.JavaConversions._ - -import com.fasterxml.jackson.core.JsonFactory - -import net.razorvine.pickle.Pickler - -import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} -import org.apache.spark.annotation.{AlphaComponent, Experimental} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{BooleanType, StructType} -import org.apache.spark.storage.StorageLevel - -/** - * :: AlphaComponent :: - * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, - * SchemaRDDs can be used in relational queries, as shown in the examples below. - * - * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD - * whose elements are scala case classes into a SchemaRDD. This conversion can also be done - * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. - * - * A `SchemaRDD` can also be created by loading data in from external sources. - * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]] - * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. - * - * == SQL Queries == - * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once - * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements. - * - * {{{ - * // One method for defining the schema of an RDD is to make a case class with the desired column - * // names and types. - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - * // Any RDD containing case classes can be registered as a table. The schema of the table is - * // automatically inferred using scala reflection. - * rdd.registerTempTable("records") - * - * val results: SchemaRDD = sql("SELECT * FROM records") - * }}} - * - * == Language Integrated Queries == - * - * {{{ - * - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i))) - * - * // Example of language integrated queries. - * rdd.where('key === 1).orderBy('value.asc).select('key).collect() - * }}} - * - * @groupname Query Language Integrated Queries - * @groupdesc Query Functions that create new queries from SchemaRDDs. The - * result of all query functions is also a SchemaRDD, allowing multiple operations to be - * chained using a builder pattern. - * @groupprio Query -2 - * @groupname schema SchemaRDD Functions - * @groupprio schema -1 - * @groupname Ungrouped Base RDD Functions - */ -@AlphaComponent -class SchemaRDD( - @transient val sqlContext: SQLContext, - @transient val baseLogicalPlan: LogicalPlan) - extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { - - def baseSchemaRDD = this - - // ========================================================================================= - // RDD functions: Copy the internal row representation so we present immutable data to users. - // ========================================================================================= - - override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) - - override def getPartitions: Array[Partition] = firstParent[Row].partitions - - override protected def getDependencies: Seq[Dependency[_]] = { - schema // Force reification of the schema so it is available on executors. - - List(new OneToOneDependency(queryExecution.toRdd)) - } - - /** - * Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - lazy val schema: StructType = queryExecution.analyzed.schema - - /** - * Returns a new RDD with each row transformed to a JSON string. - * - * @group schema - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val jsonFactory = new JsonFactory() - iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) - } - } - - - // ======================================================================= - // Query DSL - // ======================================================================= - - /** - * Changes the output of this relation to the given expressions, similar to the `SELECT` clause - * in SQL. - * - * {{{ - * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName) - * }}} - * - * @param exprs a set of logical expression that will be evaluated for each input row. - * - * @group Query - */ - def select(exprs: Expression*): SchemaRDD = { - val aliases = exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) - } - - /** - * Filters the output, only returning those rows where `condition` evaluates to true. - * - * {{{ - * schemaRDD.where('a === 'b) - * schemaRDD.where('a === 1) - * schemaRDD.where('a + 'b > 10) - * }}} - * - * @group Query - */ - def where(condition: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Filter(condition, logicalPlan)) - - /** - * Performs a relational join on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be joined with this one. - * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.` - * @param on An optional condition for the join operation. This is equivalent to the `ON` - * clause in standard SQL. In the case of `Inner` joins, specifying a - * `condition` is equivalent to adding `where` clauses after the `join`. - * - * @group Query - */ - def join( - otherPlan: SchemaRDD, - joinType: JoinType = Inner, - on: Option[Expression] = None): SchemaRDD = - new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on)) - - /** - * Sorts the results by the given expressions. - * {{{ - * schemaRDD.orderBy('a) - * schemaRDD.orderBy('a, 'b) - * schemaRDD.orderBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def orderBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan)) - - /** - * Sorts the results by the given expressions within partition. - * {{{ - * schemaRDD.sortBy('a) - * schemaRDD.sortBy('a, 'b) - * schemaRDD.sortBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def sortBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan)) - - @deprecated("use limit with integer argument", "1.1.0") - def limit(limitExpr: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) - - /** - * Limits the results by the given integer. - * {{{ - * schemaRDD.limit(10) - * }}} - * @group Query - */ - def limit(limitNum: Int): SchemaRDD = - new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan)) - - /** - * Performs a grouping followed by an aggregation. - * - * {{{ - * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() - } - new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) - } - - /** - * Performs an aggregation over all Rows in this RDD. - * This is equivalent to a groupBy with no grouping expressions. - * - * {{{ - * schemaRDD.aggregate(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def aggregate(aggregateExprs: Expression*): SchemaRDD = { - groupBy()(aggregateExprs: _*) - } - - /** - * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes - * with the same name, for example, when performing self-joins. - * - * {{{ - * val x = schemaRDD.where('a === 1).as('x) - * val y = schemaRDD.where('a === 2).as('y) - * x.join(y).where("x.a".attr === "y.a".attr), - * }}} - * - * @group Query - */ - def as(alias: Symbol) = - new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) - - /** - * Combines the tuples of two RDDs with the same schema, keeping duplicates. - * - * @group Query - */ - def unionAll(otherPlan: SchemaRDD) = - new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational except on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be excepted from this one. - * - * @group Query - */ - def except(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational intersect on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be intersected with this one. - * - * @group Query - */ - def intersect(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan)) - - /** - * Filters tuples using a function over the value of the specified column. - * - * {{{ - * schemaRDD.where('a)((a: Int) => ...) - * }}} - * - * @group Query - */ - def where[T1](arg1: Symbol)(udf: (T1) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - - /** - * :: Experimental :: - * Returns a sampled version of the underlying dataset. - * - * @group Query - */ - @Experimental - override - def sample( - withReplacement: Boolean = true, - fraction: Double, - seed: Long) = - new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) - - /** - * :: Experimental :: - * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this - * implementation leverages the query optimizer to compute the count on the SchemaRDD, which - * supports features such as filter pushdown. - * - * @group Query - */ - @Experimental - override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0) - - /** - * :: Experimental :: - * Applies the given Generator, or table generating function, to this relation. - * - * @param generator A table generating function. The API for such functions is likely to change - * in future releases - * @param join when set to true, each output row of the generator is joined with the input row - * that produced it. - * @param outer when set to true, at least one row will be produced for each input row, similar to - * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a - * given row, a single row will be output, with `NULL` values for each of the - * generated columns. - * @param alias an optional alias that can be used as qualifier for the attributes that are - * produced by this generate operation. - * - * @group Query - */ - @Experimental - def generate( - generator: Generator, - join: Boolean = false, - outer: Boolean = false, - alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) - - /** - * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit - * conversion from a standard RDD to a SchemaRDD. - * - * @group schema - */ - def toSchemaRDD = this - - /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. - */ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same - * format as javaToPython. It is used by pyspark. - */ - private[sql] def collectToPython: JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(collect().map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same - * format as javaToPython and collectToPython. It is used by pyspark. - */ - private[sql] def takeSampleToPython( - withReplacement: Boolean, - num: Int, - seed: Long): JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value - * of base RDD functions that do not change schema. - * - * @param rdd RDD derived from this one and has same schema - * - * @group schema - */ - private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, - LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) - } - - // ======================================================================= - // Overridden RDD actions - // ======================================================================= - - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() - - def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect() : _*) - - override def take(num: Int): Array[Row] = limit(num).collect() - - // ======================================================================= - // Base RDD functions that do NOT change schema - // ======================================================================= - - // Transformations (return a new RDD) - - override def coalesce(numPartitions: Int, shuffle: Boolean = false) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.coalesce(numPartitions, shuffle)(ord)) - - override def distinct(): SchemaRDD = applySchema(super.distinct()) - - override def distinct(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.distinct(numPartitions)(ord)) - - def distinct(numPartitions: Int): SchemaRDD = - applySchema(super.distinct(numPartitions)(null)) - - override def filter(f: Row => Boolean): SchemaRDD = - applySchema(super.filter(f)) - - override def intersection(other: RDD[Row]): SchemaRDD = - applySchema(super.intersection(other)) - - override def intersection(other: RDD[Row], partitioner: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.intersection(other, partitioner)(ord)) - - override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.intersection(other, numPartitions)) - - override def repartition(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.repartition(numPartitions)(ord)) - - override def subtract(other: RDD[Row]): SchemaRDD = - applySchema(super.subtract(other)) - - override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.subtract(other, numPartitions)) - - override def subtract(other: RDD[Row], p: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.subtract(other, p)(ord)) - - /** Overridden cache function will always use the in-memory columnar caching. */ - override def cache(): this.type = { - sqlContext.cacheQuery(this) - this - } - - override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheQuery(this, None, newLevel) - this - } - - override def unpersist(blocking: Boolean): this.type = { - sqlContext.tryUncacheQuery(this, blocking) - this - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala deleted file mode 100644 index 3cf9209465b76..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.LogicalRDD - -/** - * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) - */ -private[sql] trait SchemaRDDLike { - @transient def sqlContext: SQLContext - @transient val baseLogicalPlan: LogicalPlan - - private[sql] def baseSchemaRDD: SchemaRDD - - /** - * :: DeveloperApi :: - * A lazily computed query execution workflow. All other RDD operations are passed - * through to the RDD that is produced by this workflow. This workflow is produced lazily because - * invoking the whole query optimization pipeline can be expensive. - * - * The query execution is considered a Developer API as phases may be added or removed in future - * releases. This execution is only exposed to provide an interface for inspecting the various - * phases for debugging purposes. Applications should not depend on particular phases existing - * or producing any specific output, even for exactly the same query. - * - * Additionally, the RDD exposed by this execution is not designed for consumption by end users. - * In particular, it does not contain any schema information, and it reuses Row objects - * internally. This object reuse improves performance, but can make programming against the RDD - * more difficult. Instead end users should perform RDD operations on a SchemaRDD directly. - */ - @transient - @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) - - @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - baseLogicalPlan - } - - override def toString = - s"""${super.toString} - |== Query Plan == - |${queryExecution.simpleString}""".stripMargin.trim - - /** - * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that - * are written out using this method can be read back in as a SchemaRDD using the `parquetFile` - * function. - * - * @group schema - */ - def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } - - /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. - * - * @group schema - */ - def registerTempTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) - } - - @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") - def registerAsTable(tableName: String): Unit = registerTempTable(tableName) - - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit = - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd - - /** - * :: Experimental :: - * Appends the rows from this RDD to the specified table. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * :: Experimental :: - * Creates a table from the the contents of this SchemaRDD. This will fail if the table already - * exists. - * - * Note that this currently only works with SchemaRDDs that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * @group schema - */ - @Experimental - def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd - - /** Returns the schema as a string in the tree format. - * - * @group schema - */ - def schemaString: String = baseSchemaRDD.schema.treeString - - /** Prints out the schema. - * - * @group schema - */ - def printSchema(): Unit = println(schemaString) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala new file mode 100644 index 0000000000000..073d41e938478 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala @@ -0,0 +1,289 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.storage.StorageLevel + + +/** + * An internal interface defining the RDD-like methods for [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +trait RDDApi[T] { + + def cache(): this.type = persist() + + def persist(): this.type + + def persist(newLevel: StorageLevel): this.type + + def unpersist(): this.type = unpersist(blocking = false) + + def unpersist(blocking: Boolean): this.type + + def map[R: ClassTag](f: T => R): RDD[R] + + def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] + + def take(n: Int): Array[T] + + def collect(): Array[T] + + def collectAsList(): java.util.List[T] + + def count(): Long + + def first(): T + + def repartition(numPartitions: Int): DataFrame +} + + +/** + * An internal interface defining data frame related methods in [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +trait DataFrameSpecificApi { + + def schema: StructType + + def printSchema(): Unit + + def dtypes: Array[(String, String)] + + def columns: Array[String] + + def head(): Row + + def head(n: Int): Array[Row] + + ///////////////////////////////////////////////////////////////////////////// + // Relational operators + ///////////////////////////////////////////////////////////////////////////// + def apply(colName: String): Column + + def apply(projection: Product): DataFrame + + @scala.annotation.varargs + def select(cols: Column*): DataFrame + + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame + + def apply(condition: Column): DataFrame + + def as(name: String): DataFrame + + def filter(condition: Column): DataFrame + + def where(condition: Column): DataFrame + + @scala.annotation.varargs + def groupBy(cols: Column*): GroupedDataFrame + + @scala.annotation.varargs + def groupBy(col1: String, cols: String*): GroupedDataFrame + + def agg(exprs: Map[String, String]): DataFrame + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + def sort(colName: String): DataFrame + + @scala.annotation.varargs + def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame + + @scala.annotation.varargs + def sort(sortExpr: Column, sortExprs: Column*): DataFrame + + def join(right: DataFrame): DataFrame + + def join(right: DataFrame, joinExprs: Column): DataFrame + + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame + + def limit(n: Int): DataFrame + + def unionAll(other: DataFrame): DataFrame + + def intersect(other: DataFrame): DataFrame + + def except(other: DataFrame): DataFrame + + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame + + def sample(withReplacement: Boolean, fraction: Double): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // Column mutation + ///////////////////////////////////////////////////////////////////////////// + def addColumn(colName: String, col: Column): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // I/O and interaction with other frameworks + ///////////////////////////////////////////////////////////////////////////// + + def rdd: RDD[Row] + + def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + + def toJSON: RDD[String] + + def registerTempTable(tableName: String): Unit + + def saveAsParquetFile(path: String): Unit + + @Experimental + def saveAsTable(tableName: String): Unit + + @Experimental + def insertInto(tableName: String, overwrite: Boolean): Unit + + @Experimental + def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) + + ///////////////////////////////////////////////////////////////////////////// + // Stat functions + ///////////////////////////////////////////////////////////////////////////// +// def describe(): Unit +// +// def mean(): Unit +// +// def max(): Unit +// +// def min(): Unit +} + + +/** + * An internal interface defining expression APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[Column]] directly, and do NOT use this. + */ +trait ExpressionApi { + + def isComputable: Boolean + + def unary_- : Column + def unary_! : Column + def unary_~ : Column + + def + (other: Column): Column + def + (other: Any): Column + def - (other: Column): Column + def - (other: Any): Column + def * (other: Column): Column + def * (other: Any): Column + def / (other: Column): Column + def / (other: Any): Column + def % (other: Column): Column + def % (other: Any): Column + def & (other: Column): Column + def & (other: Any): Column + def | (other: Column): Column + def | (other: Any): Column + def ^ (other: Column): Column + def ^ (other: Any): Column + + def && (other: Column): Column + def && (other: Boolean): Column + def || (other: Column): Column + def || (other: Boolean): Column + + def < (other: Column): Column + def < (other: Any): Column + def <= (other: Column): Column + def <= (other: Any): Column + def > (other: Column): Column + def > (other: Any): Column + def >= (other: Column): Column + def >= (other: Any): Column + def === (other: Column): Column + def === (other: Any): Column + def equalTo(other: Column): Column + def equalTo(other: Any): Column + def <=> (other: Column): Column + def <=> (other: Any): Column + def !== (other: Column): Column + def !== (other: Any): Column + + @scala.annotation.varargs + def in(list: Column*): Column + + def like(other: Column): Column + def like(other: String): Column + def rlike(other: Column): Column + def rlike(other: String): Column + + def contains(other: Column): Column + def contains(other: Any): Column + def startsWith(other: Column): Column + def startsWith(other: String): Column + def endsWith(other: Column): Column + def endsWith(other: String): Column + + def substr(startPos: Column, len: Column): Column + def substr(startPos: Int, len: Int): Column + + def isNull: Column + def isNotNull: Column + + def getItem(ordinal: Column): Column + def getItem(ordinal: Int): Column + def getField(fieldName: String): Column + + def cast(to: DataType): Column + + def asc: Column + def desc: Column + + def as(alias: String): Column +} + + +/** + * An internal interface defining aggregation APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this. + */ +trait GroupedDataFrameApi { + + def agg(exprs: Map[String, String]): DataFrame + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + def avg(): DataFrame + + def mean(): DataFrame + + def min(): DataFrame + + def max(): DataFrame + + def sum(): DataFrame + + def count(): DataFrame + + // TODO: Add var, std +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala new file mode 100644 index 0000000000000..29c3d26ae56d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala @@ -0,0 +1,495 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Timestamp, Date} + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType + + +package object dsl { + + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** Converts $"col name" into an [[Column]]. */ + implicit class StringToColumn(val sc: StringContext) extends AnyVal { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args :_*)) + } + } + + private[this] implicit def toColumn(expr: Expression): Column = new Column(expr) + + def sum(e: Column): Column = Sum(e.expr) + def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def count(e: Column): Column = Count(e.expr) + + @scala.annotation.varargs + def countDistinct(expr: Column, exprs: Column*): Column = + CountDistinct((expr +: exprs).map(_.expr)) + + def avg(e: Column): Column = Average(e.expr) + def first(e: Column): Column = First(e.expr) + def last(e: Column): Column = Last(e.expr) + def min(e: Column): Column = Min(e.expr) + def max(e: Column): Column = Max(e.expr) + def upper(e: Column): Column = Upper(e.expr) + def lower(e: Column): Column = Lower(e.expr) + def sqrt(e: Column): Column = Sqrt(e.expr) + def abs(e: Column): Column = Abs(e.expr) + + // scalastyle:off + + object literals { + + implicit def booleanToLiteral(b: Boolean): Column = Literal(b) + + implicit def byteToLiteral(b: Byte): Column = Literal(b) + + implicit def shortToLiteral(s: Short): Column = Literal(s) + + implicit def intToLiteral(i: Int): Column = Literal(i) + + implicit def longToLiteral(l: Long): Column = Literal(l) + + implicit def floatToLiteral(f: Float): Column = Literal(f) + + implicit def doubleToLiteral(d: Double): Column = Literal(d) + + implicit def stringToLiteral(s: String): Column = Literal(s) + + implicit def dateToLiteral(d: Date): Column = Literal(d) + + implicit def bigDecimalToLiteral(d: BigDecimal): Column = Literal(d.underlying()) + + implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Column = Literal(d) + + implicit def timestampToLiteral(t: Timestamp): Column = Literal(t) + + implicit def binaryToLiteral(a: Array[Byte]): Column = Literal(a) + } + + + /* Use the following code to generate: + (0 to 22).map { x => + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf)) + }""") + } + + (0 to 22).map { x => + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val fTypes = Seq.fill(x + 1)("_").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, returnType, Seq($argsInUdf)) + }""") + } + } + */ + /** + * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag](f: Function0[RT]): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function0[_], returnType: DataType): Column = { + ScalaUdf(f, returnType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 52a31f01a4358..6fba76c52171b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical @@ -137,7 +137,9 @@ case class CacheTableCommand( isLazy: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName)) + plan.foreach { logicalPlan => + sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName) + } sqlContext.cacheTable(tableName) if (!isLazy) { @@ -159,7 +161,7 @@ case class CacheTableCommand( case class UncacheTableCommand(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - sqlContext.table(tableName).unpersist() + sqlContext.table(tableName).unpersist(blocking = false) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 4d7e338e8ed13..aeb0960e87f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -42,7 +42,7 @@ package object debug { * Augments SchemaRDDs with debug methods. */ @DeveloperApi - implicit class DebugQuery(query: SchemaRDD) { + implicit class DebugQuery(query: DataFrame) { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 6dd39be807037..7c49b5220d607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -37,5 +37,5 @@ package object sql { * Converts a logical plan into zero or more SparkPlans. */ @DeveloperApi - type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 02ce1b3e6d811..0b312ef51daa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.util import org.apache.spark.util.Utils @@ -100,7 +100,7 @@ trait ParquetTest { */ protected def withParquetRDD[T <: Product: ClassTag: TypeTag] (data: Seq[T]) - (f: SchemaRDD => Unit): Unit = { + (f: DataFrame => Unit): Unit = { withParquetFile(data)(path => f(parquetFile(path))) } @@ -120,7 +120,7 @@ trait ParquetTest { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetRDD(data) { rdd => - rdd.registerTempTable(tableName) + sqlContext.registerRDDAsTable(rdd, tableName) withTempTable(tableName)(f) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 37853d4d03019..d13f2ce2a5e1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,19 +18,18 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => pruneFilterProjectRaw( l, @@ -112,23 +111,26 @@ private[sql] object DataSourceStrategy extends Strategy { } } + /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */ protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => + LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => + GreaterThanOrEqual(a.name, v) case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 171b816a26332..b4af91a768efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SchemaRDD, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.execution.RunnableCommand @@ -225,7 +225,8 @@ private [sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext) = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName) + sqlContext.registerRDDAsTable( + new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f9c082216085d..2564c849b87f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.test import scala.language.implicitConversions import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. */ @@ -40,8 +40,8 @@ object TestSQLContext * Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to * construct SchemaRDD directly out of local data without relying on implicits. */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = { - new SchemaRDD(this, plan) + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + new DataFrame(this, plan) } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java index 9ff40471a00af..e5588938ea162 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -61,7 +61,7 @@ public Integer call(String str) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); assert(result.getInt(0) == 4); } @@ -81,7 +81,7 @@ public Integer call(String str1, String str2) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); assert(result.getInt(0) == 9); } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 9e96738ac095a..badd00d34b9b1 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -98,8 +98,8 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema); - schemaRDD.registerTempTable("people"); + DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema); + df.registerTempTable("people"); Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); @@ -147,17 +147,17 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); - StructType actualSchema1 = schemaRDD1.schema(); + DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); + StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerTempTable("jsonTable1"); + df1.registerTempTable("jsonTable1"); List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); - StructType actualSchema2 = schemaRDD2.schema(); + DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); + StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD2.registerTempTable("jsonTable2"); + df2.registerTempTable("jsonTable2"); List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index cfc037caff2a9..34763156a6d11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index afbfe214f1ce4..a5848f219cea9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.types._ /* Implicits */ -import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import scala.language.postfixOps @@ -44,46 +42,46 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( - testData2.groupBy('a)('a, sum('b)), + testData2.groupBy("a").agg($"a", sum($"b")), Seq(Row(1,3), Row(2,3), Row(3,3)) ) checkAnswer( - testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), + testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), Row(9) ) checkAnswer( - testData2.aggregate(sum('b)), + testData2.agg(sum('b)), Row(9) ) } test("convert $\"attribute name\" into unresolved attribute") { checkAnswer( - testData.where($"key" === 1).select($"value"), + testData.where($"key" === Literal(1)).select($"value"), Row("1")) } test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === Literal(1)).select('value), Row("1")) } test("select *") { checkAnswer( - testData.select(Star(None)), + testData.select($"*"), testData.collect().toSeq) } test("simple select") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === Literal(1)).select('value), Row("1")) } test("select with functions") { checkAnswer( - testData.select(sum('value), avg('value), count(1)), + testData.select(sum('value), avg('value), count(Literal(1))), Row(5050.0, 50.5, 100)) checkAnswer( @@ -120,46 +118,19 @@ class DslQuerySuite extends QueryTest { checkAnswer( arrayData.orderBy('data.getItem(0).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(0).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) - } - - test("partition wide sorting") { - // 2 partitions totally, and - // Partition #1 with values: - // (1, 1) - // (1, 2) - // (2, 1) - // Partition #2 with values: - // (2, 2) - // (3, 1) - // (3, 2) - checkAnswer( - testData2.sortBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) - - checkAnswer( - testData2.sortBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.desc), - Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.asc), - Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("limit") { @@ -176,71 +147,51 @@ class DslQuerySuite extends QueryTest { mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } - test("SPARK-3395 limit distinct") { - val filtered = TestData.testData2 - .distinct() - .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending)) - .limit(1) - .registerTempTable("onerow") - checkAnswer( - sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: Nil) - } - - test("SPARK-3858 generator qualifiers are discarded") { - checkAnswer( - arrayData.as('ad) - .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) - .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 4).map(Row(_))) - } - test("average") { checkAnswer( - testData2.aggregate(avg('a)), + testData2.agg(avg('a)), Row(2.0)) checkAnswer( - testData2.aggregate(avg('a), sumDistinct('a)), // non-partial + testData2.agg(avg('a), sumDistinct('a)), // non-partial Row(2.0, 6.0) :: Nil) checkAnswer( - decimalData.aggregate(avg('a)), + decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial + decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2))), + decimalData.agg(avg('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( - testData3.aggregate(avg('b)), + testData3.agg(avg('b)), Row(2.0)) checkAnswer( - testData3.aggregate(avg('b), countDistinct('b)), + testData3.agg(avg('b), countDistinct('b)), Row(2.0, 1)) checkAnswer( - testData3.aggregate(avg('b), sumDistinct('b)), // non-partial + testData3.agg(avg('b), sumDistinct('b)), // non-partial Row(2.0, 2.0)) } test("zero average") { checkAnswer( - emptyTableData.aggregate(avg('a)), + emptyTableData.agg(avg('a)), Row(null)) checkAnswer( - emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial + emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial Row(null, null)) } @@ -248,28 +199,28 @@ class DslQuerySuite extends QueryTest { assert(testData2.count() === testData2.map(_ => 1).count()) checkAnswer( - testData2.aggregate(count('a), sumDistinct('a)), // non-partial + testData2.agg(count('a), sumDistinct('a)), // non-partial Row(6, 6.0)) } test("null count") { checkAnswer( - testData3.groupBy('a)('a, count('b)), + testData3.groupBy('a).agg('a, count('b)), Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.groupBy('a)('a, count('a + 'b)), + testData3.groupBy('a).agg('a, count('a + 'b)), Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), + testData3.agg(count('a), count('b), count(Literal(1)), countDistinct('a), countDistinct('b)), Row(2, 1, 2, 2, 1) ) checkAnswer( - testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial + testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial Row(1, 1, 2) ) } @@ -278,19 +229,19 @@ class DslQuerySuite extends QueryTest { assert(emptyTableData.count() === 0) checkAnswer( - emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial + emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } test("zero sum") { checkAnswer( - emptyTableData.aggregate(sum('a)), + emptyTableData.agg(sum('a)), Row(null)) } test("zero sum distinct") { checkAnswer( - emptyTableData.aggregate(sumDistinct('a)), + emptyTableData.agg(sumDistinct('a)), Row(null)) } @@ -320,7 +271,7 @@ class DslQuerySuite extends QueryTest { checkAnswer( // SELECT *, foo(key, value) FROM testData - testData.select(Star(None), foo.call('key, 'value)).limit(3), + testData.select($"*", callUDF(foo, 'key, 'value)).limit(3), Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } @@ -362,7 +313,7 @@ class DslQuerySuite extends QueryTest { test("upper") { checkAnswer( lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Row(c.toString.toUpperCase())) + ('a' to 'd').map(c => Row(c.toString.toUpperCase)) ) checkAnswer( @@ -379,7 +330,7 @@ class DslQuerySuite extends QueryTest { test("lower") { checkAnswer( upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Row(c.toString.toLowerCase())) + ('A' to 'F').map(c => Row(c.toString.toLowerCase)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index cd36da7751e83..79713725c0d77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,19 +20,20 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData test("equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -105,17 +106,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("multiple-key equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, - Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } test("inner join where, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + upperCaseData.join(lowerCaseData).where('n === 'N), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("inner join ON, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -136,10 +136,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 1).as('y) + val x = testData2.where($"a" === Literal(1)).as("x") + val y = testData2.where($"a" === Literal(1)).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Row(1,1,1,1) :: Row(1,1,1,2) :: Row(1,2,1,1) :: @@ -148,22 +148,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, no matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 2).as('y) + val x = testData2.where($"a" === Literal(1)).as("x") + val y = testData2.where($"a" === Literal(2)).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Nil) } test("big inner join, 4 matches per row") { val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as('x) - val bigDataY = bigData.as('y) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") checkAnswer( - bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), - testData.flatMap( - row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) + bigDataX.join(bigDataY).where($"x.key" === $"y.key"), + testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { @@ -177,7 +176,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left outer join") { checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -186,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > Literal(1), "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -195,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > Literal(1), "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -204,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -240,7 +239,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("right outer join") { checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -248,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > Literal(1), "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > Literal(1), "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -264,7 +263,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -299,14 +298,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") + upperCaseData.where('N <= Literal(4)).registerTempTable("left") + upperCaseData.where('N >= Literal(3)).registerTempTable("right") val left = UnresolvedRelation(Seq("left"), None) val right = UnresolvedRelation(Seq("right"), None) checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + left.join(right, $"left.N" === $"right.N", "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", 3, "C") :: @@ -315,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== Literal(3)), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -325,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== Literal(3)), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 42a21c148df53..07c52de377a60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -26,12 +26,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -44,10 +44,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -91,7 +91,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } @@ -102,7 +102,7 @@ class QueryTest extends PlanTest { } /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03b44ca1d6695..4fff99cb3f3e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -29,6 +30,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ + class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData @@ -381,8 +383,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("big inner join, 4 matches per row") { - - checkAnswer( sql( """ @@ -396,7 +396,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData UNION ALL | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), - testData.flatMap( + testData.rdd.flatMap( row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } @@ -742,7 +742,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("metadata is propagated correctly") { - val person = sql("SELECT * FROM person") + val person: DataFrame = sql("SELECT * FROM person") val schema = person.schema val docKey = "doc" val docValue = "first name" @@ -751,14 +751,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person, schemaWithMeta) - def validateMetadata(rdd: SchemaRDD): Unit = { + val personWithMeta = applySchema(person.rdd, schemaWithMeta) + def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } personWithMeta.registerTempTable("personWithMeta") - validateMetadata(personWithMeta.select('name)) - validateMetadata(personWithMeta.select("name".attr)) - validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"id", $"name")) validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 808ed5288cfb8..fffa2b7dfa6e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.test._ /* Implicits */ @@ -29,11 +30,11 @@ case class TestData(key: Int, value: String) object TestData { val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD + (1 to 100).map(i => TestData(i, i.toString))).toDF testData.registerTempTable("testData") val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF negativeData.registerTempTable("negativeData") case class LargeAndSmallInts(a: Int, b: Int) @@ -44,7 +45,7 @@ object TestData { LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD + LargeAndSmallInts(3, 2) :: Nil).toDF largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) @@ -55,7 +56,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toSchemaRDD + TestData2(3, 2) :: Nil, 2).toDF testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) @@ -67,7 +68,7 @@ object TestData { DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toSchemaRDD + DecimalData(3, 2) :: Nil).toDF decimalData.registerTempTable("decimalData") case class BinaryData(a: Array[Byte], b: Int) @@ -77,17 +78,17 @@ object TestData { BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD + BinaryData("123".getBytes(), 4) :: Nil).toDF binaryData.registerTempTable("binaryData") case class TestData3(a: Int, b: Option[Int]) val testData3 = TestSQLContext.sparkContext.parallelize( TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toSchemaRDD + TestData3(2, Some(2)) :: Nil).toDF testData3.registerTempTable("testData3") - val emptyTableData = logical.LocalRelation('a.int, 'b.int) + val emptyTableData = logical.LocalRelation($"a".int, $"b".int) case class UpperCaseData(N: Int, L: String) val upperCaseData = @@ -97,7 +98,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toSchemaRDD + UpperCaseData(6, "F") :: Nil).toDF upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -106,7 +107,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toSchemaRDD + LowerCaseData(4, "d") :: Nil).toDF lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -200,6 +201,6 @@ object TestData { TestSQLContext.sparkContext.parallelize( ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toSchemaRDD + :: Nil).toDF complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 0c98120031242..5abd7b9383366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.dsl.StringToColumn import org.apache.spark.sql.test._ /* Implicits */ @@ -28,17 +29,17 @@ class UDFSuite extends QueryTest { test("Simple UDF") { udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { udf.register("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { @@ -46,7 +47,7 @@ class UDFSuite extends QueryTest { val result= sql("SELECT returnStruct('test', 'test2') as ret") - .select("ret.f1".attr).first().getString(0) - assert(result == "test") + .select($"ret.f1").head().getString(0) + assert(result === "test") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fbc8704f7837b..62b2e89403791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types._ + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -66,14 +68,14 @@ class UserDefinedTypeSuite extends QueryTest { test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[MyDenseVector] = - pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v } val featuresArrays: Array[MyDenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index e61f3c39631da..6f051dfe3d21d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 67007b8c093ca..be5e63c76f42e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -28,6 +29,7 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -40,7 +42,7 @@ class PlannerSuite extends FunSuite { } test("count is partially aggregated") { - val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed + val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } @@ -48,14 +50,14 @@ class PlannerSuite extends FunSuite { } test("count distinct is partially aggregated") { - val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed + val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } test("mixed aggregates are partially aggregated") { val query = - testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed + testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } @@ -128,9 +130,9 @@ class PlannerSuite extends FunSuite { testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") - val a = testData.as('a) - val b = table("tiny").as('b) - val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + val a = testData.as("a") + val b = table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala deleted file mode 100644 index 272c0d4cb2335..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ - -/* Implicit conversions */ -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns - * from the input data. These will be replaced during analysis with specific AttributeReferences - * and then bound to specific ordinals during query planning. While TGFs could also access specific - * columns using hand-coded ordinals, doing so violates data independence. - * - * Note: this is only a rough example of how TGFs can be expressed, the final version will likely - * involve a lot more sugar for cleaner use in Scala/Java/etc. - */ -case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { - def children = input - protected def makeOutput() = 'nameAndAge.string :: Nil - - val Seq(nameAttr, ageAttr) = input - - override def eval(input: Row): TraversableOnce[Row] = { - val name = nameAttr.eval(input) - val age = ageAttr.eval(input).asInstanceOf[Int] - - Iterator( - new GenericRow(Array[Any](s"$name is $age years old")), - new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old"))) - } -} - -class TgfSuite extends QueryTest { - val inputData = - logical.LocalRelation('name.string, 'age.int).loadData( - ("michael", 29) :: Nil - ) - - test("simple tgf example") { - checkAnswer( - inputData.generate(ExampleTGF()), - Seq( - Row("michael is 29 years old"), - Row("Next year, michael will be 30 years old"))) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 94d14acccbb18..ef198f846c53a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.{Literal, QueryTest, Row, SQLConf} class JsonSuite extends QueryTest { import org.apache.spark.sql.json.TestJsonData._ @@ -463,8 +464,8 @@ class JsonSuite extends QueryTest { // in the Project. checkAnswer( jsonSchemaRDD. - where('num_str > BigDecimal("92233720368547758060")). - select('num_str + 1.2 as Symbol("num")), + where('num_str > Literal(BigDecimal("92233720368547758060"))). + select(('num_str + Literal(1.2)).as("num")), Row(new java.math.BigDecimal("92233720368547758061.2")) ) @@ -820,7 +821,7 @@ class JsonSuite extends QueryTest { val schemaRDD1 = applySchema(rowRDD1, schema1) schemaRDD1.registerTempTable("applySchema1") - val schemaRDD2 = schemaRDD1.toSchemaRDD + val schemaRDD2 = schemaRDD1.toDF val result = schemaRDD2.toJSON.collect() assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") @@ -841,7 +842,7 @@ class JsonSuite extends QueryTest { val schemaRDD3 = applySchema(rowRDD2, schema2) schemaRDD3.registerTempTable("applySchema2") - val schemaRDD4 = schemaRDD3.toSchemaRDD + val schemaRDD4 = schemaRDD3.toDF val result2 = schemaRDD4.toJSON.collect() assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 1e7d3e06fc196..c9bc55900de98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -23,7 +23,7 @@ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -41,15 +41,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( - rdd: SchemaRDD, + rdd: DataFrame, predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Seq[Row]) => Unit, + checker: (DataFrame, Seq[Row]) => Unit, expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output: _*).where(predicate) + val query = rdd + .select(output.map(e => new org.apache.spark.sql.Column(e)): _*) + .where(new org.apache.spark.sql.Column(predicate)) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred @@ -71,13 +73,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { private def checkFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } private def checkFilterPredicate[T] (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -93,24 +95,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - integer") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -118,24 +120,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - long") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -143,24 +145,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - float") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -168,24 +170,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - double") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -197,30 +199,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { + (implicit rdd: DataFrame): Unit = { + def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } @@ -231,7 +233,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -249,16 +251,16 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate( '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index a57e4e85a35ef..f03b3a32e34e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -32,12 +32,13 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -97,11 +98,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): SchemaRDD = + def makeDecimalRDD(decimal: DecimalType): DataFrame = sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) - .select('_1 cast decimal) + .select($"_1" cast decimal as "abcd") for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 7900b3e8948d9..a33cf1172cac9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import scala.language.existentials + import org.apache.spark.sql._ import org.apache.spark.sql.types._ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7385952861ee5..bb19ac232fcbe 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -23,6 +23,7 @@ import java.io._ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} + import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 166c56b9dfe20..ea9d61d8d0f5e 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation( sessionToActivePool: SMap[SessionHandle, String]) extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index eaf7a1ddd4996..71e3954b2c7ac 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -30,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation( // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 274f83af5ac03..b746942cb1067 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} @@ -63,14 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) - override def sql(sqlText: String): SchemaRDD = { + override def sql(sqlText: String): DataFrame = { + val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { - super.sql(sqlText) + super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText))) + new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } @@ -350,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override protected[sql] val planner = hivePlanner /** Extends QueryExecution with hive specific features. */ - protected[sql] abstract class QueryExecution extends super.QueryExecution { + protected[sql] class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { /** * Returns the result as a hive compatible sequence of strings. For native commands, the diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 6952b126cf894..ace9329cd5821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} +import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate @@ -55,16 +55,15 @@ private[hive] trait HiveStrategies { */ @Experimental object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: SchemaRDD) { - def lowerCase = - new SchemaRDD(s.sqlContext, s.logicalPlan) + implicit class LogicalPlanHacks(s: DataFrame) { + def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan) def addPartitioningAttributes(attrs: Seq[Attribute]) = { // Don't add the partitioning key if its already present in the data. if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s } else { - new SchemaRDD( + new DataFrame( s.sqlContext, s.logicalPlan transform { case p: ParquetRelation => p.copy(partitioningAttributes = attrs) @@ -97,13 +96,13 @@ private[hive] trait HiveStrategies { // We are going to throw the predicates and projection back at the whole optimization // sequence so lets unresolve all the attributes, allowing them to be rebound to the // matching parquet attributes. - val unresolvedOtherPredicates = otherPredicates.map(_ transform { + val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true)) + }).reduceOption(And).getOrElse(Literal(true))) - val unresolvedProjection = projectList.map(_ transform { + val unresolvedProjection: Seq[Column] = projectList.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }) + }).map(new Column(_)) try { if (relation.hiveQlTable.isPartitioned) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 47431cef03e13..8e70ae8f56196 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -99,7 +99,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { @@ -150,8 +150,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r - protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution { - lazy val logical = HiveQl.parseSql(hql) + protected[hive] class HiveQLQueryExecution(hql: String) + extends this.QueryExecution(HiveQl.parseSql(hql)) { def hiveExec() = runSqlHive(hql) override def toString = hql + "\n" + super.toString } @@ -159,7 +159,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** * Override QueryExecution with special debug workflow. */ - abstract class QueryExecution extends super.QueryExecution { + class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { override lazy val analyzed = { val describedTables = logical match { case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index f320d732fb77a..ba391293884bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -36,12 +36,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -54,10 +54,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -101,7 +101,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f95a6b43af357..61e5117feab10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{QueryTest, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { @@ -28,7 +28,7 @@ class CachedTableSuite extends QueryTest { * Throws a test failed exception when the number of cached tables differs from the expected * number. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 0e6636d38ed3c..5775d83fcbf67 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq + testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq ) // Now overwrite. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index df72be7746ac6..d67b00bc9d08f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,11 +27,12 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) @@ -473,7 +474,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: SchemaRDD) = { + def isExplanation(result: DataFrame) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } @@ -842,7 +843,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = + def collectResults(rdd: DataFrame): Set[(String, String)] = rdd.collect().map { case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 16f77a438e1ae..a081227b4e6b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.Row +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.Row import org.apache.spark.util.Utils @@ -82,10 +83,10 @@ class HiveTableScanSuite extends HiveComparisonTest { sql("create table spark_4959 (col1 string)") sql("""insert into table spark_4959 select "hi" from src limit 1""") table("spark_4959").select( - 'col1.as('CaseSensitiveColName), - 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2") + 'col1.as("CaseSensitiveColName"), + 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") - assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi")) - assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi")) + assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f2374a215291b..dd0df1a9f6320 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -58,7 +58,7 @@ class HiveUdfSuite extends QueryTest { | getStruct(1).f3, | getStruct(1).f4, | getStruct(1).f5 FROM src LIMIT 1 - """.stripMargin).first() === Row(1, 2, 3, 4, 5)) + """.stripMargin).head() === Row(1, 2, 3, 4, 5)) } test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f6bf2dbb5d6e4..7f9f1ac7cd80d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -104,6 +104,24 @@ class SQLQuerySuite extends QueryTest { ) } + test("command substitution") { + sql("set tbl=src") + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + + sql("set hive.variable.substitute=false") // disable the substitution + sql("set tbl2=src") + intercept[Exception] { + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() + } + + sql("set hive.variable.substitute=true") // enable the substitution + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + } + test("ordering not in select") { checkAnswer( sql("SELECT key FROM src ORDER BY value"), diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index e59c24adb84af..0e285d6088ec1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } + /** + * Get the maximum remember duration across all the input streams. This is a conservative but + * safe remember duration which can be used to perform cleanup operations. + */ + def getMaxInputStreamRememberDuration(): Duration = { + inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { logDebug("DStreamGraph.writeObject used") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index e0542eda1383f..c382a12f4d099 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -211,7 +211,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval + * @deprecated As this API is not Java compatible. */ + @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0") def reduceByWindow( reduceFunc: (T, T) => T, windowDuration: Duration, @@ -220,6 +222,24 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: JFunction2[T, T, T], + windowDuration: Duration, + slideDuration: Duration + ): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) + } + /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. However, the reduction is done incrementally diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index afd3c4bc4c4fe..8be04314c4285 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } - - /** - * Clear metadata that are older than `rememberDuration` of this DStream. - * This is an internal method that should not be called directly. This - * implementation overrides the default implementation to clear received - * block information. - */ - private[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index ab9fa192191aa..7bf3c33319491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming.receiver -/** Messages sent to the NetworkReceiver. */ +import org.apache.spark.streaming.Time + +/** Messages sent to the Receiver. */ private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage +private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index d7229c2b96d0b..716cf2c7f32fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl( case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) + case CleanupOldBlocks(threshTime) => + logDebug("Received delete old batch signal") + cleanupOldBlocks(threshTime) } def ref = self @@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl( /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) + + private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = { + logDebug(s"Cleaning up blocks older then $cleanupThreshTime") + receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 39b66e1130768..8632c94349bf9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -17,12 +17,14 @@ package org.apache.spark.streaming.scheduler -import akka.actor.{ActorRef, ActorSystem, Props, Actor} -import org.apache.spark.{SparkException, SparkEnv, Logging} -import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} -import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} import scala.util.{Failure, Success, Try} +import akka.actor.{ActorRef, Props, Actor} + +import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} +import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer} + /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent @@ -206,9 +208,13 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) - timesToReschedule.foreach(time => + timesToReschedule.foreach { time => + // Allocate the related blocks when recovering from failure, because some blocks that were + // added but not allocated, are dangling in the queue after recovering, we have to allocate + // those blocks to the next batch, which is the batch they were supposed to go. + jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch jobScheduler.submitJobSet(JobSet(time, graph.generateJobs(time))) - ) + } // Restart the timer timer.start(restartTime.milliseconds) @@ -238,13 +244,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { eventActor ! DoCheckpoint(time) } else { + // If checkpointing is not enabled, then delete metadata information about + // received blocks (block data not saved in any case). Otherwise, wait for + // checkpointing of this batch to complete. + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -252,6 +262,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + + // All the checkpoint information about which batches have been processed, etc have + // been saved to checkpoints, so its safe to delete block metadata and data WAL files + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index c3d9d7b6813d3..e19ac939f9ac5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -67,7 +67,7 @@ private[streaming] class ReceivedBlockTracker( extends Logging { private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] - + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] private val logManagerOption = createLogManager() @@ -107,8 +107,14 @@ private[streaming] class ReceivedBlockTracker( lastAllocatedBatchTime = batchTime allocatedBlocks } else { - throw new SparkException(s"Unexpected allocation of blocks, " + - s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + // This situation occurs when: + // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, + // possibly processed batch job or half-processed batch job need to be processed again, + // so the batchTime will be equal to lastAllocatedBatchTime. + // 2. Slow checkpointing makes recovered batch time older than WAL recovered + // lastAllocatedBatchTime. + // This situation will only occurs in recovery time. + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") } } @@ -150,7 +156,6 @@ private[streaming] class ReceivedBlockTracker( writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) - log } /** Stop the block tracker. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 8dbb42a86e3bd..4f998869731ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -24,9 +24,8 @@ import scala.language.existentials import akka.actor._ import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - /** Clean up metadata older than the given threshold time */ - def cleanupOldMetadata(cleanupThreshTime: Time) { + /** + * Clean up the data and metadata of blocks and batches that are strictly + * older than the threshold time. Note that this does not + */ + def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) { + // Clean up old block and batch metadata receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) + + // Signal the receivers to delete old block data + if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + logInfo(s"Cleanup old received batch data: $cleanupThreshTime") + receiverInfo.values.flatMap { info => Option(info.actor) } + .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) } + } } /** Register a receiver */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 27a28bab83ed5..858ba3c9eb4e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -63,7 +63,7 @@ private[streaming] object HdfsUtils { } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { - // For local file systems, return the raw loca file system, such calls to flush() + // For local file systems, return the raw local file system, such calls to flush() // actually flushes the stream. val fs = path.getFileSystem(conf) fs match { diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index d92e7fe899a09..d4c40745658c2 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -306,7 +306,17 @@ public void testReduce() { @SuppressWarnings("unchecked") @Test - public void testReduceByWindow() { + public void testReduceByWindowWithInverse() { + testReduceByWindow(true); + } + + @SuppressWarnings("unchecked") + @Test + public void testReduceByWindowWithoutInverse() { + testReduceByWindow(false); + } + + private void testReduceByWindow(boolean withInverse) { List> inputData = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), @@ -319,8 +329,14 @@ public void testReduceByWindow() { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + JavaDStream reducedWindowed = null; + if (withInverse) { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + } else { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new Duration(2000), new Duration(1000)); + } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 40434b1f9b709..6500608bba87c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -28,21 +28,16 @@ import java.io.File */ class FailureSuite extends TestSuiteBase with Logging { - var directory = "FailureSuite" + val directory = Utils.createTempDir().getAbsolutePath val numBatches = 30 override def batchDuration = Milliseconds(1000) override def useManualClock = false - override def beforeFunction() { - super.beforeFunction() - Utils.deleteRecursively(new File(directory)) - } - override def afterFunction() { - super.afterFunction() Utils.deleteRecursively(new File(directory)) + super.afterFunction() } test("multiple failures with map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index de7e9d624bf6b..fbb7b0bfebafc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -82,15 +82,15 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.allocateBlocksToBatch(2) receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty - // Verify that batch 2 cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(2) - } + // Verify that older batches have no operation on batch allocation, + // will return the same blocks as previously allocated. + receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos - // Verify that older batches cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(1) - } + blockInfos.map(receivedBlockTracker.addBlock) + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } test("block addition, block to batch allocation and cleanup with write ahead log") { @@ -186,14 +186,14 @@ class ReceivedBlockTrackerSuite tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 } - + test("enabling write ahead log but not setting checkpoint dir") { conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") intercept[SparkException] { createTracker(setCheckpointDir = false) } } - + test("setting checkpoint dir but not enabling write ahead log") { // When WAL config is not set, log manager should not be enabled val tracker1 = createTracker(setCheckpointDir = true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e26c0c6859e57..e8c34a9ee40b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -17,21 +17,26 @@ package org.apache.spark.streaming +import java.io.File import java.nio.ByteBuffer import java.util.concurrent.Semaphore +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor} -import org.scalatest.FunSuite +import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ + /** Testsuite for testing the network receiver behavior */ -class ReceiverSuite extends FunSuite with Timeouts { +class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("receiver life cycle") { @@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts { val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") - println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes) assert( // the first and last block may be incomplete, so we slice them out recordedBlocks.drop(1).dropRight(1).forall { block => @@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts { ) } - /** - * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. + * Test whether write ahead logs are generated by received, + * and automatically cleaned up. The clean up must be aware of the + * remember duration of the input streams. E.g., input streams on which window() + * has been applied must remember the data for longer, and hence corresponding + * WALs should be cleaned later. */ - class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - @volatile var otherThread: Thread = null - @volatile var receiving = false - @volatile var onStartCalled = false - @volatile var onStopCalled = false - - def onStart() { - otherThread = new Thread() { - override def run() { - receiving = true - while(!isStopped()) { - Thread.sleep(10) - } + test("write ahead log - generating and cleaning") { + val sparkConf = new SparkConf() + .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers + .setAppName(framework) + .set("spark.ui.enabled", "true") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val batchDuration = Milliseconds(500) + val tempDirectory = Files.createTempDir() + val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) + val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) + val allLogFiles1 = new mutable.HashSet[String]() + val allLogFiles2 = new mutable.HashSet[String]() + logInfo("Temp checkpoint directory = " + tempDirectory) + + def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = { + (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2)) + } + + def getCurrentLogFiles(logDirectory: File): Seq[String] = { + try { + if (logDirectory.exists()) { + logDirectory1.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + } else { + Seq.empty } + } catch { + case e: Exception => + Seq.empty } - onStartCalled = true - otherThread.start() - } - def onStop() { - onStopCalled = true - otherThread.join() + def printLogFiles(message: String, files: Seq[String]) { + logInfo(s"$message (${files.size} files):\n" + files.mkString("\n")) } - def reset() { - receiving = false - onStartCalled = false - onStopCalled = false + withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => + tempDirectory.deleteOnExit() + val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiverStream1 = ssc.receiverStream(receiver1) + val receiverStream2 = ssc.receiverStream(receiver2) + receiverStream1.register() + receiverStream2.window(batchDuration * 6).register() // 3 second window + ssc.checkpoint(tempDirectory.getAbsolutePath()) + ssc.start() + + // Run until sufficient WAL files have been generated and + // the first WAL files has been deleted + eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) { + val (logFiles1, logFiles2) = getBothCurrentLogFiles() + allLogFiles1 ++= logFiles1 + allLogFiles2 ++= logFiles2 + if (allLogFiles1.size > 0) { + assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head)) + } + if (allLogFiles2.size > 0) { + assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head)) + } + assert(allLogFiles1.size >= 7) + assert(allLogFiles2.size >= 7) + } + ssc.stop(stopSparkContext = true, stopGracefully = true) + + val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted + val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted + val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles() + + printLogFiles("Receiver 0: all", sortedAllLogFiles1) + printLogFiles("Receiver 0: left", leftLogFiles1) + printLogFiles("Receiver 1: all", sortedAllLogFiles2) + printLogFiles("Receiver 1: left", leftLogFiles2) + + // Verify that necessary latest log files are not deleted + // receiverStream1 needs to retain just the last batch = 1 log file + // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files + assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains)) + assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains)) } } @@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts { } } +/** + * An implementation of Receiver that is used for testing a receiver's life cycle. + */ +class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false + + def onStart() { + otherThread = new Thread() { + override def run() { + receiving = true + var count = 0 + while(!isStopped()) { + if (sendData) { + store(count) + count += 1 + } + Thread.sleep(10) + } + } + } + onStartCalled = true + otherThread.start() + } + + def onStop() { + onStopCalled = true + otherThread.join() + } + + def reset() { + receiving = false + onStartCalled = false + onStopCalled = false + } +} + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4c35b60c57df3..d00f29665a58f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -60,7 +60,6 @@ private[yarn] class YarnAllocator( import YarnAllocator._ - // These two complementary data structures are locked on allocatedHostToContainersMap. // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] @@ -355,20 +354,18 @@ private[yarn] class YarnAllocator( } } - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).get + val containerSet = allocatedHostToContainersMap.get(host).get - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) } + + allocatedContainerToHostMap.remove(containerId) } } }
    Property NameDefaultMeaning
    spark.driver.extraJavaOptions(none) + A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. +
    spark.driver.extraClassPath(none) + Extra classpath entries to append to the classpath of the driver. +
    spark.driver.extraLibraryPath(none) + Set a special library path to use when launching the driver JVM. +
    spark.executor.extraJavaOptions (none)