diff --git a/assembly/pom.xml b/assembly/pom.xml index 4f6aade133db7..567a8dd2a0d94 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -39,6 +39,7 @@ spark /usr/share/spark root + 744 @@ -276,7 +277,7 @@ ${deb.user} ${deb.user} ${deb.install.path}/bin - 744 + ${deb.bin.filemode} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index 70a99b33d753c..ef0bb2ac13f08 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -72,6 +72,7 @@ object Bagel extends Logging { var verts = vertices var msgs = messages var noActivity = false + var lastRDD: RDD[(K, (V, Array[M]))] = null do { logInfo("Starting superstep " + superstep + ".") val startTime = System.currentTimeMillis @@ -83,6 +84,10 @@ object Bagel extends Logging { val superstep_ = superstep // Create a read-only copy of superstep for capture in closure val (processed, numMsgs, numActiveVerts) = comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) + if (lastRDD != null) { + lastRDD.unpersist(false) + } + lastRDD = processed val timeTaken = System.currentTimeMillis - startTime logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3d8373d8175ee..3b5642b6caa36 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -269,6 +269,9 @@ object SparkSubmit { sysProps.getOrElseUpdate(k, v) } + // Spark properties included on command line take precedence + sysProps ++= args.sparkProperties + (childArgs, childClasspath, sysProps, childMainClass) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 57655aa4c32b1..3ab67a43a3b55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -55,6 +55,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + val sparkProperties: HashMap[String, String] = new HashMap[String, String]() parseOpts(args.toList) loadDefaults() @@ -177,6 +178,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | executorCores $executorCores | totalExecutorCores $totalExecutorCores | propertiesFile $propertiesFile + | extraSparkProperties $sparkProperties | driverMemory $driverMemory | driverCores $driverCores | driverExtraClassPath $driverExtraClassPath @@ -290,6 +292,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { jars = Utils.resolveURIs(value) parse(tail) + case ("--conf" | "-c") :: value :: tail => + value.split("=", 2).toSeq match { + case Seq(k, v) => sparkProperties(k) = v + case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") + } + parse(tail) + case ("--help" | "-h") :: tail => printUsageAndExit(0) @@ -349,6 +358,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working | directory of each executor. + | + | --conf PROP=VALUE Arbitrary Spark configuration property. | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a8c9ac072449f..01e7065c17b69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -169,7 +169,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val ui: SparkUI = if (renderUI) { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId) + new SparkUI(conf, appSecManager, replayBus, appId, + HistoryServer.UI_PATH_PREFIX + s"/$appId") // Do not call ui.bind() to avoid creating a new server for each application } else { null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index a958c837c2ff6..d7a3e3f120e67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -75,7 +75,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Last Updated") private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - val uiAddress = "/history/" + info.id + val uiAddress = HistoryServer.UI_PATH_PREFIX + s"/${info.id}" val startTime = UIUtils.formatDate(info.startTime) val endTime = UIUtils.formatDate(info.endTime) val duration = UIUtils.formatDuration(info.endTime - info.startTime) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 56b38ddfc9313..cacb9da8c947b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -114,7 +114,7 @@ class HistoryServer( attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) val contextHandler = new ServletContextHandler - contextHandler.setContextPath("/history") + contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) contextHandler.addServlet(new ServletHolder(loaderServlet), "/*") attachHandler(contextHandler) } @@ -172,6 +172,8 @@ class HistoryServer( object HistoryServer extends Logging { private val conf = new SparkConf + val UI_PATH_PREFIX = "/history" + def main(argStrings: Array[String]) { SignalLogger.register(log) initSecurity() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index bb1fcc8190fe4..21f8667819c44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -35,6 +35,7 @@ import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI @@ -664,9 +665,10 @@ private[spark] class Master( */ def rebuildSparkUI(app: ApplicationInfo): Boolean = { val appName = app.desc.name + val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" val eventLogDir = app.desc.eventLogDir.getOrElse { // Event logging is not enabled for this application - app.desc.appUiUrl = "/history/not-found" + app.desc.appUiUrl = notFoundBasePath return false } val fileSystem = Utils.getHadoopFileSystem(eventLogDir) @@ -681,13 +683,14 @@ private[spark] class Master( logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = s"/history/not-found?msg=$msg&title=$title" + app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" return false } try { val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) - val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id) + val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", + HistoryServer.UI_PATH_PREFIX + s"/${app.id}") replayBus.replay() appIdToUI(app.id) = ui webUi.attachSparkUI(ui) @@ -702,7 +705,7 @@ private[spark] class Master( var msg = s"Exception in replaying log for application $appName!" logError(msg, e) msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = s"/history/not-found?msg=$msg&exception=$exception&title=$title" + app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title" false } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index a90b0d475c04e..ae6ca9f4e7bf5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -63,6 +63,13 @@ private[spark] class EventLoggingListener( // For testing. Keep track of all JSON serialized events that have been logged. private[scheduler] val loggedEvents = new ArrayBuffer[JValue] + /** + * Return only the unique application directory without the base directory. + */ + def getApplicationLogDir(): String = { + name + } + /** * Begin logging events. * If compression is used, log a file that indicates which compression library is used. diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index e9f6273bfd9f0..5b897597fa285 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -57,7 +57,7 @@ private[spark] class LocalActor( case StatusUpdate(taskId, state, serializedData) => scheduler.statusUpdate(taskId, state, serializedData) if (TaskState.isFinished(state)) { - freeCores += 1 + freeCores += scheduler.CPUS_PER_TASK reviveOffers() } @@ -68,7 +68,7 @@ private[spark] class LocalActor( def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { - freeCores -= 1 + freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1ce4243194798..c3a3e90a34901 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -48,6 +48,7 @@ class KryoSerializer(conf: SparkConf) private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) + private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val registrator = conf.getOption("spark.kryo.registrator") def newKryoOutput() = new KryoOutput(bufferSize) @@ -55,6 +56,7 @@ class KryoSerializer(conf: SparkConf) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() + kryo.setRegistrationRequired(registrationRequired) val classLoader = Thread.currentThread.getContextClassLoader // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. @@ -185,7 +187,8 @@ private[serializer] object KryoSerializer { classOf[MapStatus], classOf[BlockManagerId], classOf[Array[Byte]], - classOf[BoundedPriorityQueue[_]] + classOf[BoundedPriorityQueue[_]], + classOf[SparkConf] ) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 5f45c0ced5ec5..f8b308c981548 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import scala.xml.Node +import scala.xml.Text import java.util.Date @@ -99,19 +100,30 @@ private[ui] class StageTableBase( {s.name} + val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { - +show details - - + +details + ++ + } val stageDataOption = listener.stageIdToData.get(s.stageId) // Too many nested map/flatMaps with options are just annoying to read. Do this imperatively. if (stageDataOption.isDefined && stageDataOption.get.description.isDefined) { val desc = stageDataOption.get.description -
{desc}
{nameLink} {killLink}
+
{desc}
{killLink} {nameLink} {details}
} else {
{killLink} {nameLink} {details}
} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 3448aaaf5724c..bb6079154aafe 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -257,7 +257,8 @@ private[spark] object JsonProtocol { val reason = Utils.getFormattedClassName(taskEndReason) val json = taskEndReason match { case fetchFailed: FetchFailed => - val blockManagerAddress = blockManagerIdToJson(fetchFailed.bmAddress) + val blockManagerAddress = Option(fetchFailed.bmAddress). + map(blockManagerIdToJson).getOrElse(JNothing) ("Block Manager Address" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 565c53e9529ff..f497a5e0a14f0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -120,6 +120,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -139,6 +140,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") + sysProps("spark.shuffle.spill") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") } @@ -156,6 +158,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -176,6 +179,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") + sysProps("spark.shuffle.spill") should be ("false") } test("handles standalone cluster mode") { @@ -186,6 +190,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--supervise", "--driver-memory", "4g", "--driver-cores", "5", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -195,9 +200,10 @@ class SparkSubmitSuite extends FunSuite with Matchers { childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") mainClass should be ("org.apache.spark.deploy.Client") classpath should have size (0) - sysProps should have size (2) + sysProps should have size (3) sysProps.keys should contain ("spark.jars") sysProps.keys should contain ("SPARK_SUBMIT") + sysProps("spark.shuffle.spill") should be ("false") } test("handles standalone client mode") { @@ -208,6 +214,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -218,6 +225,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") + sysProps("spark.shuffle.spill") should be ("false") } test("handles mesos client mode") { @@ -228,6 +236,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -238,6 +247,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") + sysProps("spark.shuffle.spill") should be ("false") } test("launch simple application with spark-submit") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 86b443b18f2a6..c52368b5514db 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -475,6 +475,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY assert(manager.myLocalityLevels.sameElements( Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) + FakeRackUtil.cleanUp() } test("test RACK_LOCAL tasks") { @@ -505,6 +506,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host2 // Task 1 can be scheduled with RACK_LOCAL assert(manager.resourceOffer("execB", "host2", RACK_LOCAL).get.index === 1) + FakeRackUtil.cleanUp() } test("do not emit warning when serialized task is small") { diff --git a/docs/configuration.md b/docs/configuration.md index a70007c165442..cb0c65e2d2200 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -42,13 +42,15 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My fancy app" --master local[4] myApp.jar +./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false + --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit) tool support two ways to load configurations dynamically. The first are command line options, -such as `--master`, as shown above. Running `./bin/spark-submit --help` will show the entire list -of options. +such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` +flag, but uses special flags for properties that play a part in launching the Spark application. +Running `./bin/spark-submit --help` will show the entire list of these options. `bin/spark-submit` will also read configuration options from `conf/spark-defaults.conf`, in which each line consists of a key and a value separated by whitespace. For example: @@ -388,6 +390,17 @@ Apart from these, the following properties are also available, and may be useful case. + + spark.kryo.registrationRequired + false + + Whether to require registration with Kryo. If set to 'true', Kryo will throw an exception + if an unregistered class is serialized. If set to false (the default), Kryo will write + unregistered class names along with each object. Writing class names can cause + significant performance overhead, so enabling this option can enforce strictly that a + user has not omitted classes from registration. + + spark.kryoserializer.buffer.mb 2 @@ -497,9 +510,9 @@ Apart from these, the following properties are also available, and may be useful spark.hadoop.validateOutputSpecs true - If set to true, validates the output specification (e.g. checking if the output directory already exists) - used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing - output directories. We recommend that users do not disable this except if trying to achieve compatibility with + If set to true, validates the output specification (e.g. checking if the output directory already exists) + used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing + output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. @@ -861,7 +874,7 @@ Apart from these, the following properties are also available, and may be useful #### Cluster Managers -Each cluster manager in Spark has additional configuration options. Configurations +Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: * [YARN](running-on-yarn.html#configuration) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e05883072bfa8..45b70b1a5457a 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -33,6 +33,7 @@ dependencies, and can support different cluster managers and deploy modes that S --class --master \ --deploy-mode \ + --conf = \ ... # other options \ [application-arguments] @@ -43,6 +44,7 @@ Some of the commonly used options are: * `--class`: The entry point for your application (e.g. `org.apache.spark.examples.SparkPi`) * `--master`: The [master URL](#master-urls) for the cluster (e.g. `spark://23.195.26.187:7077`) * `--deploy-mode`: Whether to deploy your driver on the worker nodes (`cluster`) or locally as an external client (`client`) (default: `client`)* +* `--conf`: Arbitrary Spark configuration property in key=value format. For values that contain spaces wrap "key=value" in quotes (as shown). * `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. * `application-arguments`: Arguments passed to the main method of your main class, if any diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 5ea2e5549d7df..4eacc47da5699 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -63,7 +63,8 @@ class TwitterReceiver( storageLevel: StorageLevel ) extends Receiver[Status](storageLevel) with Logging { - private var twitterStream: TwitterStream = _ + @volatile private var twitterStream: TwitterStream = _ + @volatile private var stopped = false def onStart() { try { @@ -78,7 +79,9 @@ class TwitterReceiver( def onScrubGeo(l: Long, l1: Long) {} def onStallWarning(stallWarning: StallWarning) {} def onException(e: Exception) { - restart("Error receiving tweets", e) + if (!stopped) { + restart("Error receiving tweets", e) + } } }) @@ -91,12 +94,14 @@ class TwitterReceiver( } setTwitterStream(newTwitterStream) logInfo("Twitter receiver started") + stopped = false } catch { case e: Exception => restart("Error starting Twitter stream", e) } } def onStop() { + stopped = true setTwitterStream(null) logInfo("Twitter receiver stopped") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala index eea9fe9520caa..1948c978c30bf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator { def registerClasses(kryo: Kryo) { kryo.register(classOf[Edge[Object]]) - kryo.register(classOf[RoutingTableMessage]) kryo.register(classOf[(VertexId, Object)]) kryo.register(classOf[EdgePartition[Object, Object]]) kryo.register(classOf[BitSet]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 502b112d31c2e..a565d3b28bf52 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -/** - * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that - * the edge partition references `vid` in the specified `position` (src, dst, or both). -*/ -private[graphx] -class RoutingTableMessage( - var vid: VertexId, - var pid: PartitionID, - var position: Byte) - extends Product2[VertexId, (PartitionID, Byte)] with Serializable { - override def _1 = vid - override def _2 = (pid, position) - override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage] -} +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage]( + new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage]( self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } @@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions { private[graphx] object RoutingTablePartition { + /** + * A message from an edge partition to a vertex specifying the position in which the edge + * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower + * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int. + */ + type RoutingTableMessage = (VertexId, Int) + + private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = { + val positionUpper2 = position << 30 + val pidLower30 = pid & 0x3FFFFFFF + (vid, positionUpper2 | pidLower30) + } + + private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1 + private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF + private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte + val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty) /** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */ @@ -77,7 +81,9 @@ object RoutingTablePartition { map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) } map.iterator.map { vidAndPosition => - new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2) + val vid = vidAndPosition._1 + val position = vidAndPosition._2 + toMessage(vid, pid, position) } } @@ -88,9 +94,12 @@ object RoutingTablePartition { val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) for (msg <- iter) { - pid2vid(msg.pid) += msg.vid - srcFlags(msg.pid) += (msg.position & 0x1) != 0 - dstFlags(msg.pid) += (msg.position & 0x2) != 0 + val vid = vidFromMessage(msg) + val pid = pidFromMessage(msg) + val position = positionFromMessage(msg) + pid2vid(pid) += vid + srcFlags(pid) += (position & 0x1) != 0 + dstFlags(pid) += (position & 0x2) != 0 } new RoutingTablePartition(pid2vid.zipWithIndex.map { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index 2d98c24d6970e..3909efcdfc993 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -24,9 +24,11 @@ import java.nio.ByteBuffer import scala.reflect.ClassTag -import org.apache.spark.graphx._ import org.apache.spark.serializer._ +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage + private[graphx] class RoutingTableMessageSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { @@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { new ShuffleSerializationStream(s) { def writeObject[T: ClassTag](t: T): SerializationStream = { val msg = t.asInstanceOf[RoutingTableMessage] - writeVarLong(msg.vid, optimizePositive = false) - writeUnsignedVarInt(msg.pid) - // TODO: Write only the bottom two bits of msg.position - s.write(msg.position) + writeVarLong(msg._1, optimizePositive = false) + writeInt(msg._2) this } } @@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { new ShuffleDeserializationStream(s) { override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) - val b = readUnsignedVarInt() - val c = s.read() - if (c == -1) throw new EOFException - new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T] + val b = readInt() + (a, b).asInstanceOf[T] } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala index ff17edeaf8f16..6aab28ff05355 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala @@ -30,7 +30,7 @@ package object graphx { */ type VertexId = Long - /** Integer identifer of a graph partition. */ + /** Integer identifer of a graph partition. Must be less than 2^30. */ // TODO: Consider using Char. type PartitionID = Int diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 9d16182f9d8c4..94db1dc183230 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,8 +20,26 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { + + // TODO: move utility functions to TestingUtils. + + def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = { + actual.zip(expected).forall { case (x1, x2) => + x1.almostEquals(x2) + } + } + + def elementsAlmostEqual( + actual: Seq[(Double, Double)], + expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = { + actual.zip(expected).forall { case ((x1, y1), (x2, y2)) => + x1.almostEquals(x2) && y1.almostEquals(y2) + } + } + test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) @@ -41,14 +59,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val prCurve = Seq((0.0, 1.0)) ++ pr val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(metrics.thresholds().collect().toSeq === threshold) - assert(metrics.roc().collect().toSeq === rocCurve) - assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve)) - assert(metrics.pr().collect().toSeq === prCurve) - assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve)) - assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1)) - assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2)) - assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision)) - assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall)) + assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold)) + assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve)) + assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve))) + assert(elementsAlmostEqual(metrics.pr().collect(), prCurve)) + assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve))) + assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))) + assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))) + assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision))) + assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall))) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5e5ddd227aab6..e9220db6b1f9a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -32,108 +32,83 @@ import com.typesafe.tools.mima.core._ */ object MimaExcludes { - def excludes(version: String) = version match { - case v if v.startsWith("1.1") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - closures.map(method => ProblemFilters.exclude[MissingMethodProblem](method)) ++ - Seq( - // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), - // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values - // for countApproxDistinct* functions, which does not work in Java. We later removed - // them, and use the following to tell Mima to not care about them. - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$debugChildren$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$firstDebugString$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$shuffleDebugString$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$debugString$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" - + "createZero$1") - ) ++ - Seq( - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") - ) ++ - Seq( // Ignore some private methods in ALS. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. - "org.apache.spark.mllib.recommendation.ALS.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ - MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ - MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ - MimaBuild.excludeSparkClass("storage.Values") ++ - MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ - Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Variance.calculate") - ) - case v if v.startsWith("1.0") => - Seq( - MimaBuild.excludeSparkPackage("api.java"), - MimaBuild.excludeSparkPackage("mllib"), - MimaBuild.excludeSparkPackage("streaming") - ) ++ - MimaBuild.excludeSparkClass("rdd.ClassTags") ++ - MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ - MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ - MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ - MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ - MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ - MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") - case _ => Seq() - } - - private val closures = Seq( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$mergeMaps$1", - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$countPartition$1", - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$distributePartition$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeValue$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeToFile$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$reducePartition$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeShard$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeCombiners$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$process$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$createCombiner$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeMaps$1" - ) + def excludes(version: String) = + version match { + case v if v.startsWith("1.1") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + Seq( + // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), + // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values + // for countApproxDistinct* functions, which does not work in Java. We later removed + // them, and use the following to tell Mima to not care about them. + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry") + ) ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") + ) ++ + Seq( // Ignore some private methods in ALS. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. + "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ + MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ + MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) + case v if v.startsWith("1.0") => + Seq( + MimaBuild.excludeSparkPackage("api.java"), + MimaBuild.excludeSparkPackage("mllib"), + MimaBuild.excludeSparkPackage("streaming") + ) ++ + MimaBuild.excludeSparkClass("rdd.ClassTags") ++ + MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ + MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ + MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ + MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ + MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ + MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + case _ => Seq() + } } diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b50590ab3b444..b4c82f519bd53 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -100,6 +100,12 @@ def set(self, key, value): self._jconf.set(key, unicode(value)) return self + def setIfMissing(self, key, value): + """Set a configuration property, if not already set.""" + if self.get(key) is None: + self.set(key, value) + return self + def setMaster(self, value): """Set master URL to connect to.""" self._jconf.setMaster(value) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e21be0e10a3f7..024fb881877c9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -101,7 +101,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - + self._conf.setIfMissing("spark.rdd.compress", "true") # Set any parameters passed directly to us on the conf if master: self._conf.setMaster(master) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 94ba22306afbd..a38dd0b9237c5 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -231,10 +231,10 @@ def context(self): def cache(self): """ - Persist this RDD with the default storage level (C{MEMORY_ONLY}). + Persist this RDD with the default storage level (C{MEMORY_ONLY_SER}). """ self.is_cached = True - self._jrdd.cache() + self.persist(StorageLevel.MEMORY_ONLY_SER) return self def persist(self, storageLevel): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c7188469bfb86..02bdb64f308a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ - /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing * when all relations are already filled in and the analyser needs only to resolve attribute @@ -54,6 +53,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool StarExpansion :: ResolveFunctions :: GlobalAggregates :: + UnresolvedHavingClauseAttributes :: typeCoercionRules :_*), Batch("Check Analysis", Once, CheckResolution), @@ -151,6 +151,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * This rule finds expressions in HAVING clause filters that depend on + * unresolved attributes. It pushes these expressions down to the underlying + * aggregates and then projects them away above the filter. + */ + object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) + if !filter.resolved && aggregate.resolved && containsAggregate(havingCondition) => { + val evaluatedCondition = Alias(havingCondition, "havingCondition")() + val aggExprsWithHaving = evaluatedCondition +: originalAggExprs + + Project(aggregate.output, + Filter(evaluatedCondition.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } + + } + + protected def containsAggregate(condition: Expression): Boolean = + condition + .collect { case ae: AggregateExpression => ae } + .nonEmpty + } + /** * When a SELECT clause has only a single expression and that expression is a * [[catalyst.expressions.Generator Generator]] we convert the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 76ddeba9cb312..9887856b9c1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -231,10 +231,20 @@ trait HiveTypeCoercion { * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. */ object BooleanComparisons extends Rule[LogicalPlan] { + val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, BigDecimal(1)).map(Literal(_)) + val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, BigDecimal(0)).map(Literal(_)) + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // No need to change EqualTo operators as that actually makes sense for boolean types. + + // Hive treats (true = 1) as true and (false = 0) as true. + case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l + case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r + case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) + case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) + + // No need to change other EqualTo operators as that actually makes sense for boolean types. case e: EqualTo => e // Otherwise turn them to Byte types so that there exists and ordering. case p: BinaryComparison diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1b503b957d146..15c98efbcabcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -79,8 +79,24 @@ package object dsl { def === (other: Expression) = EqualTo(expr, other) def !== (other: Expression) = Not(EqualTo(expr, other)) + def in(list: Expression*) = In(expr, list) + def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) + def contains(other: Expression) = Contains(expr, other) + def startsWith(other: Expression) = StartsWith(expr, other) + def endsWith(other: Expression) = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + Substring(expr, pos, len) + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + Substring(expr, pos, len) + + def isNull = IsNull(expr) + def isNotNull = IsNotNull(expr) + + def getItem(ordinal: Expression) = GetItem(expr, ordinal) + def getField(fieldName: String) = GetField(expr, fieldName) + def cast(to: DataType) = Cast(expr, to) def asc = SortOrder(expr, Ascending) @@ -112,6 +128,7 @@ package object dsl { def sumDistinct(e: Expression) = SumDistinct(e) def count(e: Expression) = Count(e) def countDistinct(e: Expression*) = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) def avg(e: Expression) = Average(e) def first(e: Expression) = First(e) def min(e: Expression) = Min(e) @@ -163,6 +180,18 @@ package object dsl { /** Creates a new AttributeReference of type binary */ def binary = AttributeReference(s, BinaryType, nullable = true)() + + /** Creates a new AttributeReference of type array */ + def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + + /** Creates a new AttributeReference of type map */ + def map(keyType: DataType, valueType: DataType): AttributeReference = + map(MapType(keyType, valueType)) + def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + /** Creates a new AttributeReference of type struct */ + def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) + def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index db1ae29d400c6..c3f5c26fdbe59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -301,17 +301,17 @@ class ExpressionEvaluationSuite extends FunSuite { val c3 = 'a.boolean.at(2) val c4 = 'a.boolean.at(3) - checkEvaluation(IsNull(c1), false, row) - checkEvaluation(IsNotNull(c1), true, row) + checkEvaluation(c1.isNull, false, row) + checkEvaluation(c1.isNotNull, true, row) - checkEvaluation(IsNull(c2), true, row) - checkEvaluation(IsNotNull(c2), false, row) + checkEvaluation(c2.isNull, true, row) + checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(IsNull(Literal(1, ShortType)), false) - checkEvaluation(IsNotNull(Literal(1, ShortType)), true) + checkEvaluation(Literal(1, ShortType).isNull, false) + checkEvaluation(Literal(1, ShortType).isNotNull, true) - checkEvaluation(IsNull(Literal(null, ShortType)), true) - checkEvaluation(IsNotNull(Literal(null, ShortType)), false) + checkEvaluation(Literal(null, ShortType).isNull, true) + checkEvaluation(Literal(null, ShortType).isNotNull, false) checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) @@ -326,11 +326,11 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(If(Literal(false, BooleanType), Literal("a", StringType), Literal("b", StringType)), "b", row) - checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row) - checkEvaluation(In(Literal("^Ba*n", StringType), - Literal("^Ba*n", StringType) :: Nil), true, row) - checkEvaluation(In(Literal("^Ba*n", StringType), - Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) + checkEvaluation(c1 in (c1, c2), true, row) + checkEvaluation( + Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + checkEvaluation( + Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) } test("case when") { @@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite { assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + + checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) + checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) + checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row) } test("arithmetic") { @@ -472,20 +476,20 @@ class ExpressionEvaluationSuite extends FunSuite { val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) - checkEvaluation(Contains(c1, "b"), true, row) - checkEvaluation(Contains(c1, "x"), false, row) - checkEvaluation(Contains(c2, "b"), null, row) - checkEvaluation(Contains(c1, Literal(null, StringType)), null, row) + checkEvaluation(c1 contains "b", true, row) + checkEvaluation(c1 contains "x", false, row) + checkEvaluation(c2 contains "b", null, row) + checkEvaluation(c1 contains Literal(null, StringType), null, row) - checkEvaluation(StartsWith(c1, "a"), true, row) - checkEvaluation(StartsWith(c1, "b"), false, row) - checkEvaluation(StartsWith(c2, "a"), null, row) - checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row) + checkEvaluation(c1 startsWith "a", true, row) + checkEvaluation(c1 startsWith "b", false, row) + checkEvaluation(c2 startsWith "a", null, row) + checkEvaluation(c1 startsWith Literal(null, StringType), null, row) - checkEvaluation(EndsWith(c1, "c"), true, row) - checkEvaluation(EndsWith(c1, "b"), false, row) - checkEvaluation(EndsWith(c2, "b"), null, row) - checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row) + checkEvaluation(c1 endsWith "c", true, row) + checkEvaluation(c1 endsWith "b", false, row) + checkEvaluation(c2 endsWith "b", null, row) + checkEvaluation(c1 endsWith Literal(null, StringType), null, row) } test("Substring") { @@ -542,5 +546,10 @@ class ExpressionEvaluationSuite extends FunSuite { assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + + checkEvaluation(s.substr(0, 2), "ex", row) + checkEvaluation(s.substr(0), "example", row) + checkEvaluation(s.substring(0, 2), "ex", row) + checkEvaluation(s.substring(0), "example", row) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 34b355e906695..34654447a5f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -24,10 +24,10 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} -import com.twitter.chill.AllScalaRegistrar +import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils @@ -48,22 +48,41 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co } } -private[sql] object SparkSqlSerializer { - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. - @transient lazy val ser: KryoSerializer = { +private[execution] class KryoResourcePool(size: Int) + extends ResourcePool[SerializerInstance](size) { + + val ser: KryoSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + // TODO (lian) Using KryoSerializer here is workaround, needs further investigation + // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization + // related error. new KryoSerializer(sparkConf) } - def serialize[T: ClassTag](o: T): Array[Byte] = { - ser.newInstance().serialize(o).array() - } + def newInstance() = ser.newInstance() +} - def deserialize[T: ClassTag](bytes: Array[Byte]): T = { - ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) +private[sql] object SparkSqlSerializer { + @transient lazy val resourcePool = new KryoResourcePool(30) + + private[this] def acquireRelease[O](fn: SerializerInstance => O): O = { + val kryo = resourcePool.borrow + try { + fn(kryo) + } finally { + resourcePool.release(kryo) + } } + + def serialize[T: ClassTag](o: T): Array[Byte] = + acquireRelease { k => + k.serialize(o).array() + } + + def deserialize[T: ClassTag](bytes: Array[Byte]): T = + acquireRelease { k => + k.deserialize[T](ByteBuffer.wrap(bytes)) + } } private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index a3fac2a5adbb9..85396f26142e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.json -import scala.collection.JavaConversions._ +import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal import com.fasterxml.jackson.databind.ObjectMapper @@ -218,12 +218,12 @@ private[sql] object JsonRDD extends Logging { case (k, dataType) => (s"$key.$k", dataType) } ++ Set((key, StructType(Nil))) } - case (key: String, array: List[_]) => { + case (key: String, array: Seq[_]) => { // The value associated with the key is an array. typeOfArray(array) match { case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. - array.asInstanceOf[List[Map[String, Any]]].flatMap { + array.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { case (k, dataType) => (s"$key.$k", dataType) @@ -238,7 +238,7 @@ private[sql] object JsonRDD extends Logging { } /** - * Converts a Java Map/List to a Scala Map/List. + * Converts a Java Map/List to a Scala Map/Seq. * We do not use Jackson's scala module at here because * DefaultScalaModule in jackson-module-scala will make * the parsing very slow. @@ -248,9 +248,9 @@ private[sql] object JsonRDD extends Logging { // .map(identity) is used as a workaround of non-serializable Map // generated by .mapValues. // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 - map.toMap.mapValues(scalafy).map(identity) + JMapWrapper(map).mapValues(scalafy).map(identity) case list: java.util.List[_] => - list.toList.map(scalafy) + JListWrapper(list).map(scalafy) case atom => atom } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index c8ea01c4e1b6a..1a6a6c17473a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test._ /* Implicits */ @@ -41,15 +40,15 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( - testData2.groupBy('a)('a, Sum('b)), + testData2.groupBy('a)('a, sum('b)), Seq((1,3),(2,3),(3,3)) ) checkAnswer( - testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)), + testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), 9 ) checkAnswer( - testData2.aggregate(Sum('b)), + testData2.aggregate(sum('b)), 9 ) } @@ -104,19 +103,19 @@ class DslQuerySuite extends QueryTest { Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) checkAnswer( - arrayData.orderBy(GetItem('data, 0).asc), + arrayData.orderBy('data.getItem(0).asc), arrayData.collect().sortBy(_.data(0)).toSeq) checkAnswer( - arrayData.orderBy(GetItem('data, 0).desc), + arrayData.orderBy('data.getItem(0).desc), arrayData.collect().sortBy(_.data(0)).reverse.toSeq) checkAnswer( - mapData.orderBy(GetItem('data, 1).asc), + mapData.orderBy('data.getItem(1).asc), mapData.collect().sortBy(_.data(1)).toSeq) checkAnswer( - mapData.orderBy(GetItem('data, 1).desc), + mapData.orderBy('data.getItem(1).desc), mapData.collect().sortBy(_.data(1)).reverse.toSeq) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala similarity index 99% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bd036faaa6354..8b451973a47a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -391,6 +391,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_sort_8", "groupby_sort_9", "groupby_sort_test_1", + "having", + "having1", "implicit_cast1", "innerjoin", "inoutdriver", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index f30ae28b81e06..1699ffe06ce15 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -102,6 +102,36 @@ test
+ + + + hive + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + compatibility/src/test/scala + + + + + + + + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala new file mode 100644 index 0000000000000..28b1a43d85773 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.{io => hadoopIo} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types +import org.apache.spark.sql.catalyst.types._ + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +private[hive] trait HiveInspectors { + + def javaClassToDataType(clz: Class[_]): DataType = clz match { + // writable + case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType + case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType + case c: Class[_] if c == classOf[hadoopIo.Text] => StringType + case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType + case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType + case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType + case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType + case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType + + // java class + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[Array[Byte]] => BinaryType + case c: Class[_] if c == classOf[java.lang.Short] => ShortType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType + case c: Class[_] if c == classOf[java.lang.Long] => LongType + case c: Class[_] if c == classOf[java.lang.Double] => DoubleType + case c: Class[_] if c == classOf[java.lang.Byte] => ByteType + case c: Class[_] if c == classOf[java.lang.Float] => FloatType + case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + + // primitive type + case c: Class[_] if c == java.lang.Short.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + case c: Class[_] if c == java.lang.Long.TYPE => LongType + case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + case c: Class[_] if c == java.lang.Byte.TYPE => ByteType + case c: Class[_] if c == java.lang.Float.TYPE => FloatType + case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + } + + /** Converts hive types to native catalyst types. */ + def unwrap(a: Any): Any = a match { + case null => null + case i: hadoopIo.IntWritable => i.get + case t: hadoopIo.Text => t.toString + case l: hadoopIo.LongWritable => l.get + case d: hadoopIo.DoubleWritable => d.get + case d: hiveIo.DoubleWritable => d.get + case s: hiveIo.ShortWritable => s.get + case b: hadoopIo.BooleanWritable => b.get + case b: hiveIo.ByteWritable => b.get + case b: hadoopIo.FloatWritable => b.get + case b: hadoopIo.BytesWritable => { + val bytes = new Array[Byte](b.getLength) + System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) + bytes + } + case t: hiveIo.TimestampWritable => t.getTimestamp + case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) + case list: java.util.List[_] => list.map(unwrap) + case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap + case array: Array[_] => array.map(unwrap).toSeq + case p: java.lang.Short => p + case p: java.lang.Long => p + case p: java.lang.Float => p + case p: java.lang.Integer => p + case p: java.lang.Double => p + case p: java.lang.Byte => p + case p: java.lang.Boolean => p + case str: String => str + case p: java.math.BigDecimal => p + case p: Array[Byte] => p + case p: java.sql.Timestamp => p + } + + def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + case hvoi: HiveVarcharObjectInspector => + if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue + case hdoi: HiveDecimalObjectInspector => + if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) + case li: ListObjectInspector => + Option(li.getList(data)) + .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) + .orNull + case mi: MapObjectInspector => + Option(mi.getMap(data)).map( + _.map { + case (k,v) => + (unwrapData(k, mi.getMapKeyObjectInspector), + unwrapData(v, mi.getMapValueObjectInspector)) + }.toMap).orNull + case si: StructObjectInspector => + val allRefs = si.getAllStructFieldRefs + new GenericRow( + allRefs.map(r => + unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + } + + /** Converts native catalyst types to the types expected by Hive */ + def wrap(a: Any): AnyRef = a match { + case s: String => new hadoopIo.Text(s) // TODO why should be Text? + case i: Int => i: java.lang.Integer + case b: Boolean => b: java.lang.Boolean + case f: Float => f: java.lang.Float + case d: Double => d: java.lang.Double + case l: Long => l: java.lang.Long + case l: Short => l: java.lang.Short + case l: Byte => l: java.lang.Byte + case b: BigDecimal => b.bigDecimal + case b: Array[Byte] => b + case t: java.sql.Timestamp => t + case s: Seq[_] => seqAsJavaList(s.map(wrap)) + case m: Map[_,_] => + mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + case null => null + } + + def toInspector(dataType: DataType): ObjectInspector = dataType match { + case ArrayType(tpe, _) => + ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType) => + ObjectInspectorFactory.getStandardMapObjectInspector( + toInspector(keyType), toInspector(valueType)) + case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector + case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector + case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector + case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) + } + + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { + case s: StructObjectInspector => + StructType(s.getAllStructFieldRefs.map(f => { + types.StructField( + f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) + })) + case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) + case m: MapObjectInspector => + MapType( + inspectorToDataType(m.getMapKeyObjectInspector), + inspectorToDataType(m.getMapValueObjectInspector)) + case _: WritableStringObjectInspector => StringType + case _: JavaStringObjectInspector => StringType + case _: WritableIntObjectInspector => IntegerType + case _: JavaIntObjectInspector => IntegerType + case _: WritableDoubleObjectInspector => DoubleType + case _: JavaDoubleObjectInspector => DoubleType + case _: WritableBooleanObjectInspector => BooleanType + case _: JavaBooleanObjectInspector => BooleanType + case _: WritableLongObjectInspector => LongType + case _: JavaLongObjectInspector => LongType + case _: WritableShortObjectInspector => ShortType + case _: JavaShortObjectInspector => ShortType + case _: WritableByteObjectInspector => ByteType + case _: JavaByteObjectInspector => ByteType + case _: WritableFloatObjectInspector => FloatType + case _: JavaFloatObjectInspector => FloatType + case _: WritableBinaryObjectInspector => BinaryType + case _: JavaBinaryObjectInspector => BinaryType + case _: WritableHiveDecimalObjectInspector => DecimalType + case _: JavaHiveDecimalObjectInspector => DecimalType + case _: WritableTimestampObjectInspector => TimestampType + case _: JavaTimestampObjectInspector => TimestampType + } + + implicit class typeInfoConversions(dt: DataType) { + import org.apache.hadoop.hive.serde2.typeinfo._ + import TypeInfoFactory._ + + def toTypeInfo: TypeInfo = dt match { + case BinaryType => binaryTypeInfo + case BooleanType => booleanTypeInfo + case ByteType => byteTypeInfo + case DoubleType => doubleTypeInfo + case FloatType => floatTypeInfo + case IntegerType => intTypeInfo + case LongType => longTypeInfo + case ShortType => shortTypeInfo + case StringType => stringTypeInfo + case DecimalType => decimalTypeInfo + case TimestampType => timestampTypeInfo + case NullType => voidTypeInfo + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 53480a521dd14..c4ca9f362a04d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command private[hive] case class SourceCommand(filePath: String) extends Command -private[hive] case class AddJar(jarPath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ @@ -229,7 +227,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8)) + NativeCommand(sql) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d216450e04557..057eb60a02612 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry - extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors { +private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { + + def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is @@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( sys.error(s"Couldn't find function $name")) + val functionClassName = functionInfo.getFunctionClass.getName() + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val function = createFunction[UDF](name) + val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) HiveSimpleUdf( - name, + functionClassName, children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(name, children) + HiveGenericUdf(functionClassName, children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(name, children) + HiveGenericUdaf(functionClassName, children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(name, Nil, children) + HiveGenericUdtf(functionClassName, Nil, children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } } - - def javaClassToDataType(clz: Class[_]): DataType = clz match { - // writable - case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType - case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType - case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType - case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType - case c: Class[_] if c == classOf[hadoopIo.Text] => StringType - case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType - case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType - case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType - case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType - case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType - - // java class - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType - case c: Class[_] if c == classOf[Array[Byte]] => BinaryType - case c: Class[_] if c == classOf[java.lang.Short] => ShortType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Byte] => ByteType - case c: Class[_] if c == classOf[java.lang.Float] => FloatType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - - // primitive type - case c: Class[_] if c == java.lang.Short.TYPE => ShortType - case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType - case c: Class[_] if c == java.lang.Long.TYPE => LongType - case c: Class[_] if c == java.lang.Double.TYPE => DoubleType - case c: Class[_] if c == java.lang.Byte.TYPE => ByteType - case c: Class[_] if c == java.lang.Float.TYPE => FloatType - case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - - case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) - } } private[hive] trait HiveFunctionFactory { - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) - def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass - def createFunction[UDFType](name: String) = - getFunctionClass(name).newInstance.asInstanceOf[UDFType] - - /** Converts hive types to native catalyst types. */ - def unwrap(a: Any): Any = a match { - case null => null - case i: hadoopIo.IntWritable => i.get - case t: hadoopIo.Text => t.toString - case l: hadoopIo.LongWritable => l.get - case d: hadoopIo.DoubleWritable => d.get - case d: hiveIo.DoubleWritable => d.get - case s: hiveIo.ShortWritable => s.get - case b: hadoopIo.BooleanWritable => b.get - case b: hiveIo.ByteWritable => b.get - case b: hadoopIo.FloatWritable => b.get - case b: hadoopIo.BytesWritable => { - val bytes = new Array[Byte](b.getLength) - System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) - bytes - } - case t: hiveIo.TimestampWritable => t.getTimestamp - case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) - case list: java.util.List[_] => list.map(unwrap) - case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap - case array: Array[_] => array.map(unwrap).toSeq - case p: java.lang.Short => p - case p: java.lang.Long => p - case p: java.lang.Float => p - case p: java.lang.Integer => p - case p: java.lang.Double => p - case p: java.lang.Byte => p - case p: java.lang.Boolean => p - case str: String => str - case p: java.math.BigDecimal => p - case p: Array[Byte] => p - case p: java.sql.Timestamp => p - } + val functionClassName: String + + def createFunction[UDFType]() = + getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType] } private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory { @@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu type UDFType type EvaluatedType = Any - val name: String - def nullable = true def references = children.flatMap(_.references).toSet - // FunctionInfo is not serializable so we must look it up here again. - lazy val functionInfo = getFunctionInfo(name) - lazy val function = createFunction[UDFType](name) + lazy val function = createFunction[UDFType]() - override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } -private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { +private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) + extends HiveUdf { + import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) } } -private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) +private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression]) extends HiveUdf with HiveInspectors { import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ @@ -277,132 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) } } -private[hive] trait HiveInspectors { - - def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { - case hvoi: HiveVarcharObjectInspector => - if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue - case hdoi: HiveDecimalObjectInspector => - if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) - case li: ListObjectInspector => - Option(li.getList(data)) - .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) - .orNull - case mi: MapObjectInspector => - Option(mi.getMap(data)).map( - _.map { - case (k,v) => - (unwrapData(k, mi.getMapKeyObjectInspector), - unwrapData(v, mi.getMapValueObjectInspector)) - }.toMap).orNull - case si: StructObjectInspector => - val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) - } - - /** Converts native catalyst types to the types expected by Hive */ - def wrap(a: Any): AnyRef = a match { - case s: String => new hadoopIo.Text(s) // TODO why should be Text? - case i: Int => i: java.lang.Integer - case b: Boolean => b: java.lang.Boolean - case f: Float => f: java.lang.Float - case d: Double => d: java.lang.Double - case l: Long => l: java.lang.Long - case l: Short => l: java.lang.Short - case l: Byte => l: java.lang.Byte - case b: BigDecimal => b.bigDecimal - case b: Array[Byte] => b - case t: java.sql.Timestamp => t - case s: Seq[_] => seqAsJavaList(s.map(wrap)) - case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) - case null => null - } - - def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe, _) => - ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => - ObjectInspectorFactory.getStandardMapObjectInspector( - toInspector(keyType), toInspector(valueType)) - case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector - case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector - case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector - case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector - case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector - case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector - case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector - case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector - case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector - case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector - case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector - case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector - case StructType(fields) => - ObjectInspectorFactory.getStandardStructObjectInspector( - fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) - } - - def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { - case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { - types.StructField( - f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) - case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) - case m: MapObjectInspector => - MapType( - inspectorToDataType(m.getMapKeyObjectInspector), - inspectorToDataType(m.getMapValueObjectInspector)) - case _: WritableStringObjectInspector => StringType - case _: JavaStringObjectInspector => StringType - case _: WritableIntObjectInspector => IntegerType - case _: JavaIntObjectInspector => IntegerType - case _: WritableDoubleObjectInspector => DoubleType - case _: JavaDoubleObjectInspector => DoubleType - case _: WritableBooleanObjectInspector => BooleanType - case _: JavaBooleanObjectInspector => BooleanType - case _: WritableLongObjectInspector => LongType - case _: JavaLongObjectInspector => LongType - case _: WritableShortObjectInspector => ShortType - case _: JavaShortObjectInspector => ShortType - case _: WritableByteObjectInspector => ByteType - case _: JavaByteObjectInspector => ByteType - case _: WritableFloatObjectInspector => FloatType - case _: JavaFloatObjectInspector => FloatType - case _: WritableBinaryObjectInspector => BinaryType - case _: JavaBinaryObjectInspector => BinaryType - case _: WritableHiveDecimalObjectInspector => DecimalType - case _: JavaHiveDecimalObjectInspector => DecimalType - case _: WritableTimestampObjectInspector => TimestampType - case _: JavaTimestampObjectInspector => TimestampType - } - - implicit class typeInfoConversions(dt: DataType) { - import org.apache.hadoop.hive.serde2.typeinfo._ - import TypeInfoFactory._ - - def toTypeInfo: TypeInfo = dt match { - case BinaryType => binaryTypeInfo - case BooleanType => booleanTypeInfo - case ByteType => byteTypeInfo - case DoubleType => doubleTypeInfo - case FloatType => floatTypeInfo - case IntegerType => intTypeInfo - case LongType => longTypeInfo - case ShortType => shortTypeInfo - case StringType => stringTypeInfo - case DecimalType => decimalTypeInfo - case TimestampType => timestampTypeInfo - case NullType => voidTypeInfo - } - } -} - private[hive] case class HiveGenericUdaf( - name: String, + functionClassName: String, children: Seq[Expression]) extends AggregateExpression with HiveInspectors with HiveFunctionFactory { @@ -410,7 +207,7 @@ private[hive] case class HiveGenericUdaf( type UDFType = AbstractGenericUDAFResolver @transient - protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name) + protected lazy val resolver: AbstractGenericUDAFResolver = createFunction() @transient protected lazy val objectInspector = { @@ -427,9 +224,9 @@ private[hive] case class HiveGenericUdaf( def references: Set[Attribute] = children.map(_.references).flatten.toSet - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" - def newInstance() = new HiveUdafFunction(name, children, this) + def newInstance() = new HiveUdafFunction(functionClassName, children, this) } /** @@ -444,7 +241,7 @@ private[hive] case class HiveGenericUdaf( * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUdtf( - name: String, + functionClassName: String, aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors with HiveFunctionFactory { @@ -452,7 +249,7 @@ private[hive] case class HiveGenericUdtf( override def references = children.flatMap(_.references).toSet @transient - protected lazy val function: GenericUDTF = createFunction(name) + protected lazy val function: GenericUDTF = createFunction() protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) @@ -507,11 +304,11 @@ private[hive] case class HiveGenericUdtf( } } - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } private[hive] case class HiveUdafFunction( - functionName: String, + functionClassName: String, exprs: Seq[Expression], base: AggregateExpression) extends AggregateFunction @@ -520,7 +317,7 @@ private[hive] case class HiveUdafFunction( def this() = this(null, null, null) - private val resolver = createFunction[AbstractGenericUDAFResolver](functionName) + private val resolver = createFunction[AbstractGenericUDAFResolver]() private val inspectors = exprs.map(_.dataType).map(toInspector).toArray diff --git a/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 new file mode 100644 index 0000000000000..4d1ebdcde2c71 --- /dev/null +++ b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 @@ -0,0 +1 @@ +true true true true true true false false false false false false false false false false false false true true true true true true false false false false false false false false false false false false diff --git a/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55 b/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 new file mode 100644 index 0000000000000..704f1e62f14c5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 @@ -0,0 +1,10 @@ +4 +4 +5 +4 +5 +5 +4 +4 +5 +4 diff --git a/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1 b/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e new file mode 100644 index 0000000000000..b56757a60f780 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e @@ -0,0 +1,308 @@ +0 val_0 +2 val_2 +4 val_4 +5 val_5 +8 val_8 +9 val_9 +10 val_10 +11 val_11 +12 val_12 +15 val_15 +17 val_17 +18 val_18 +19 val_19 +20 val_20 +24 val_24 +26 val_26 +27 val_27 +28 val_28 +30 val_30 +33 val_33 +34 val_34 +35 val_35 +37 val_37 +41 val_41 +42 val_42 +43 val_43 +44 val_44 +47 val_47 +51 val_51 +53 val_53 +54 val_54 +57 val_57 +58 val_58 +64 val_64 +65 val_65 +66 val_66 +67 val_67 +69 val_69 +70 val_70 +72 val_72 +74 val_74 +76 val_76 +77 val_77 +78 val_78 +80 val_80 +82 val_82 +83 val_83 +84 val_84 +85 val_85 +86 val_86 +87 val_87 +90 val_90 +92 val_92 +95 val_95 +96 val_96 +97 val_97 +98 val_98 +100 val_100 +103 val_103 +104 val_104 +105 val_105 +111 val_111 +113 val_113 +114 val_114 +116 val_116 +118 val_118 +119 val_119 +120 val_120 +125 val_125 +126 val_126 +128 val_128 +129 val_129 +131 val_131 +133 val_133 +134 val_134 +136 val_136 +137 val_137 +138 val_138 +143 val_143 +145 val_145 +146 val_146 +149 val_149 +150 val_150 +152 val_152 +153 val_153 +155 val_155 +156 val_156 +157 val_157 +158 val_158 +160 val_160 +162 val_162 +163 val_163 +164 val_164 +165 val_165 +166 val_166 +167 val_167 +168 val_168 +169 val_169 +170 val_170 +172 val_172 +174 val_174 +175 val_175 +176 val_176 +177 val_177 +178 val_178 +179 val_179 +180 val_180 +181 val_181 +183 val_183 +186 val_186 +187 val_187 +189 val_189 +190 val_190 +191 val_191 +192 val_192 +193 val_193 +194 val_194 +195 val_195 +196 val_196 +197 val_197 +199 val_199 +200 val_200 +201 val_201 +202 val_202 +203 val_203 +205 val_205 +207 val_207 +208 val_208 +209 val_209 +213 val_213 +214 val_214 +216 val_216 +217 val_217 +218 val_218 +219 val_219 +221 val_221 +222 val_222 +223 val_223 +224 val_224 +226 val_226 +228 val_228 +229 val_229 +230 val_230 +233 val_233 +235 val_235 +237 val_237 +238 val_238 +239 val_239 +241 val_241 +242 val_242 +244 val_244 +247 val_247 +248 val_248 +249 val_249 +252 val_252 +255 val_255 +256 val_256 +257 val_257 +258 val_258 +260 val_260 +262 val_262 +263 val_263 +265 val_265 +266 val_266 +272 val_272 +273 val_273 +274 val_274 +275 val_275 +277 val_277 +278 val_278 +280 val_280 +281 val_281 +282 val_282 +283 val_283 +284 val_284 +285 val_285 +286 val_286 +287 val_287 +288 val_288 +289 val_289 +291 val_291 +292 val_292 +296 val_296 +298 val_298 +305 val_305 +306 val_306 +307 val_307 +308 val_308 +309 val_309 +310 val_310 +311 val_311 +315 val_315 +316 val_316 +317 val_317 +318 val_318 +321 val_321 +322 val_322 +323 val_323 +325 val_325 +327 val_327 +331 val_331 +332 val_332 +333 val_333 +335 val_335 +336 val_336 +338 val_338 +339 val_339 +341 val_341 +342 val_342 +344 val_344 +345 val_345 +348 val_348 +351 val_351 +353 val_353 +356 val_356 +360 val_360 +362 val_362 +364 val_364 +365 val_365 +366 val_366 +367 val_367 +368 val_368 +369 val_369 +373 val_373 +374 val_374 +375 val_375 +377 val_377 +378 val_378 +379 val_379 +382 val_382 +384 val_384 +386 val_386 +389 val_389 +392 val_392 +393 val_393 +394 val_394 +395 val_395 +396 val_396 +397 val_397 +399 val_399 +400 val_400 +401 val_401 +402 val_402 +403 val_403 +404 val_404 +406 val_406 +407 val_407 +409 val_409 +411 val_411 +413 val_413 +414 val_414 +417 val_417 +418 val_418 +419 val_419 +421 val_421 +424 val_424 +427 val_427 +429 val_429 +430 val_430 +431 val_431 +432 val_432 +435 val_435 +436 val_436 +437 val_437 +438 val_438 +439 val_439 +443 val_443 +444 val_444 +446 val_446 +448 val_448 +449 val_449 +452 val_452 +453 val_453 +454 val_454 +455 val_455 +457 val_457 +458 val_458 +459 val_459 +460 val_460 +462 val_462 +463 val_463 +466 val_466 +467 val_467 +468 val_468 +469 val_469 +470 val_470 +472 val_472 +475 val_475 +477 val_477 +478 val_478 +479 val_479 +480 val_480 +481 val_481 +482 val_482 +483 val_483 +484 val_484 +485 val_485 +487 val_487 +489 val_489 +490 val_490 +491 val_491 +492 val_492 +493 val_493 +494 val_494 +495 val_495 +496 val_496 +497 val_497 +498 val_498 diff --git a/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4 b/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff new file mode 100644 index 0000000000000..2d7022e386303 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff @@ -0,0 +1,199 @@ +4 +5 +8 +9 +26 +27 +28 +30 +33 +34 +35 +37 +41 +42 +43 +44 +47 +51 +53 +54 +57 +58 +64 +65 +66 +67 +69 +70 +72 +74 +76 +77 +78 +80 +82 +83 +84 +85 +86 +87 +90 +92 +95 +96 +97 +98 +256 +257 +258 +260 +262 +263 +265 +266 +272 +273 +274 +275 +277 +278 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +291 +292 +296 +298 +302 +305 +306 +307 +308 +309 +310 +311 +315 +316 +317 +318 +321 +322 +323 +325 +327 +331 +332 +333 +335 +336 +338 +339 +341 +342 +344 +345 +348 +351 +353 +356 +360 +362 +364 +365 +366 +367 +368 +369 +373 +374 +375 +377 +378 +379 +382 +384 +386 +389 +392 +393 +394 +395 +396 +397 +399 +400 +401 +402 +403 +404 +406 +407 +409 +411 +413 +414 +417 +418 +419 +421 +424 +427 +429 +430 +431 +432 +435 +436 +437 +438 +439 +443 +444 +446 +448 +449 +452 +453 +454 +455 +457 +458 +459 +460 +462 +463 +466 +467 +468 +469 +470 +472 +475 +477 +478 +479 +480 +481 +482 +483 +484 +485 +487 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 diff --git a/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095 b/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e new file mode 100644 index 0000000000000..bd545ccf7430c --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e @@ -0,0 +1,125 @@ +302 +305 +306 +307 +308 +309 +310 +311 +315 +316 +317 +318 +321 +322 +323 +325 +327 +331 +332 +333 +335 +336 +338 +339 +341 +342 +344 +345 +348 +351 +353 +356 +360 +362 +364 +365 +366 +367 +368 +369 +373 +374 +375 +377 +378 +379 +382 +384 +386 +389 +392 +393 +394 +395 +396 +397 +399 +400 +401 +402 +403 +404 +406 +407 +409 +411 +413 +414 +417 +418 +419 +421 +424 +427 +429 +430 +431 +432 +435 +436 +437 +438 +439 +443 +444 +446 +448 +449 +452 +453 +454 +455 +457 +458 +459 +460 +462 +463 +466 +467 +468 +469 +470 +472 +475 +477 +478 +479 +480 +481 +482 +483 +484 +485 +487 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 diff --git a/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103 b/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff new file mode 100644 index 0000000000000..d77586c12b6af --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff @@ -0,0 +1,199 @@ +4 val_4 +5 val_5 +8 val_8 +9 val_9 +26 val_26 +27 val_27 +28 val_28 +30 val_30 +33 val_33 +34 val_34 +35 val_35 +37 val_37 +41 val_41 +42 val_42 +43 val_43 +44 val_44 +47 val_47 +51 val_51 +53 val_53 +54 val_54 +57 val_57 +58 val_58 +64 val_64 +65 val_65 +66 val_66 +67 val_67 +69 val_69 +70 val_70 +72 val_72 +74 val_74 +76 val_76 +77 val_77 +78 val_78 +80 val_80 +82 val_82 +83 val_83 +84 val_84 +85 val_85 +86 val_86 +87 val_87 +90 val_90 +92 val_92 +95 val_95 +96 val_96 +97 val_97 +98 val_98 +256 val_256 +257 val_257 +258 val_258 +260 val_260 +262 val_262 +263 val_263 +265 val_265 +266 val_266 +272 val_272 +273 val_273 +274 val_274 +275 val_275 +277 val_277 +278 val_278 +280 val_280 +281 val_281 +282 val_282 +283 val_283 +284 val_284 +285 val_285 +286 val_286 +287 val_287 +288 val_288 +289 val_289 +291 val_291 +292 val_292 +296 val_296 +298 val_298 +302 val_302 +305 val_305 +306 val_306 +307 val_307 +308 val_308 +309 val_309 +310 val_310 +311 val_311 +315 val_315 +316 val_316 +317 val_317 +318 val_318 +321 val_321 +322 val_322 +323 val_323 +325 val_325 +327 val_327 +331 val_331 +332 val_332 +333 val_333 +335 val_335 +336 val_336 +338 val_338 +339 val_339 +341 val_341 +342 val_342 +344 val_344 +345 val_345 +348 val_348 +351 val_351 +353 val_353 +356 val_356 +360 val_360 +362 val_362 +364 val_364 +365 val_365 +366 val_366 +367 val_367 +368 val_368 +369 val_369 +373 val_373 +374 val_374 +375 val_375 +377 val_377 +378 val_378 +379 val_379 +382 val_382 +384 val_384 +386 val_386 +389 val_389 +392 val_392 +393 val_393 +394 val_394 +395 val_395 +396 val_396 +397 val_397 +399 val_399 +400 val_400 +401 val_401 +402 val_402 +403 val_403 +404 val_404 +406 val_406 +407 val_407 +409 val_409 +411 val_411 +413 val_413 +414 val_414 +417 val_417 +418 val_418 +419 val_419 +421 val_421 +424 val_424 +427 val_427 +429 val_429 +430 val_430 +431 val_431 +432 val_432 +435 val_435 +436 val_436 +437 val_437 +438 val_438 +439 val_439 +443 val_443 +444 val_444 +446 val_446 +448 val_448 +449 val_449 +452 val_452 +453 val_453 +454 val_454 +455 val_455 +457 val_457 +458 val_458 +459 val_459 +460 val_460 +462 val_462 +463 val_463 +466 val_466 +467 val_467 +468 val_468 +469 val_469 +470 val_470 +472 val_472 +475 val_475 +477 val_477 +478 val_478 +479 val_479 +480 val_480 +481 val_481 +482 val_482 +483 val_483 +484 val_484 +485 val_485 +487 val_487 +489 val_489 +490 val_490 +491 val_491 +492 val_492 +493 val_493 +494 val_494 +495 val_495 +496 val_496 +497 val_497 +498 val_498 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index eb7df717284ce..6f36a4f8cb905 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,6 +30,18 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("boolean = number", + """ + |SELECT + | 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y, + | 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y, + | 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y, + | 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y, + | 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y, + | 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y + |FROM src LIMIT 1 + """.stripMargin) + test("CREATE TABLE AS runs once") { hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 03a73f92b275e..566983675bff5 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -99,9 +99,25 @@ object GenerateMIMAIgnore { (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) } + /** Scala reflection does not let us see inner function even if they are upgraded + * to public for some reason. So had to resort to java reflection to get all inner + * functions with $$ in there name. + */ + def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = { + try { + Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName) + .filter(_.contains("$$")).map(classSymbol.fullName + "." + _) + } catch { + case t: Throwable => + println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + Seq.empty[String] + } + } + private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { classSymbol.typeSignature.members - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) + .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ + getInnerFunctions(classSymbol) } def main(args: Array[String]) { @@ -121,7 +137,8 @@ object GenerateMIMAIgnore { name.endsWith("$class") || name.contains("$sp") || name.contains("hive") || - name.contains("Hive") + name.contains("Hive") || + name.contains("repl") } /** diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3ec36487dcd26..62b5c3bc5f0f3 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -60,6 +60,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var yarnAllocator: YarnAllocationHandler = _ private var isFinished: Boolean = false private var uiAddress: String = _ + private var uiHistoryAddress: String = _ private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) private var isLastAMRetry: Boolean = true @@ -237,6 +238,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, if (null != sparkContext) { uiAddress = sparkContext.ui.appUIHostPort + uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, resourceManager, @@ -360,7 +362,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, finishReq.setAppAttemptId(appAttemptId) finishReq.setFinishApplicationStatus(status) finishReq.setDiagnostics(diagnostics) - finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) + finishReq.setTrackingUrl(uiHistoryAddress) resourceManager.finishApplicationMaster(finishReq) } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index a86ad256dfa39..184e2ad6c82cd 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -28,7 +28,6 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ -import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -57,10 +56,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var yarnAllocator: YarnAllocationHandler = _ - private var driverClosed:Boolean = false + + private var driverClosed: Boolean = false + private var isFinished: Boolean = false + private var registered: Boolean = false + + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) val securityManager = new SecurityManager(sparkConf) - val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, + val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ @@ -97,23 +103,26 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() - val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() - - // Compute number of threads for akka - val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - - if (minimumMemory > 0) { - val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnAllocationHandler.MEMORY_OVERHEAD) - val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. - // args.workerCores = numCore + synchronized { + if (!isFinished) { + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + registered = true } } - waitForSparkMaster() addAmIpFilter() // Allocate all containers @@ -243,11 +252,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) + checkNumExecutorsFailed() Thread.sleep(100) } logInfo("All executors have launched.") - + } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + } } // TODO: We might want to extend this to allocate more containers in case they die ! @@ -257,6 +272,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { while (!driverClosed) { + checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { logInfo("Allocating " + missingExecutorCount + @@ -282,15 +298,23 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp yarnAllocator.allocateContainers(0) } - def finishApplicationMaster(status: FinalApplicationStatus) { - - logInfo("finish ApplicationMaster with " + status) - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(appAttemptId) - finishReq.setFinishApplicationStatus(status) - finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) - resourceManager.finishApplicationMaster(finishReq) + def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { + synchronized { + if (isFinished) { + return + } + logInfo("Unregistering ApplicationMaster with " + status) + if (registered) { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + finishReq.setFinishApplicationStatus(status) + finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) + finishReq.setDiagnostics(appMessage) + resourceManager.finishApplicationMaster(finishReq) + } + isFinished = true + } } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 718cb19f57261..e98308cdbd74e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,6 +30,9 @@ import org.apache.hadoop.util.StringInterner import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.SparkHadoopUtil /** @@ -132,4 +135,17 @@ object YarnSparkHadoopUtil { } } + def getUIHistoryAddress(sc: SparkContext, conf: SparkConf) : String = { + val eventLogDir = sc.eventLogger match { + case Some(logger) => logger.getApplicationLogDir() + case None => "" + } + val historyServerAddress = conf.get("spark.yarn.historyServer.address", "") + if (historyServerAddress != "" && eventLogDir != "") { + historyServerAddress + HistoryServer.UI_PATH_PREFIX + s"/$eventLogDir" + } else { + "" + } + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d8266f7b0c9a7..f8fb96b312f23 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import scala.collection.mutable.ArrayBuffer @@ -37,6 +37,8 @@ private[spark] class YarnClientSchedulerBackend( var client: Client = null var appId: ApplicationId = null + var checkerThread: Thread = null + var stopping: Boolean = false private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { @@ -54,6 +56,7 @@ private[spark] class YarnClientSchedulerBackend( val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort) + conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf)) val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( @@ -85,6 +88,7 @@ private[spark] class YarnClientSchedulerBackend( client = new Client(args, conf) appId = client.runApp() waitForApp() + checkerThread = yarnApplicationStateCheckerThread() } def waitForApp() { @@ -115,7 +119,32 @@ private[spark] class YarnClientSchedulerBackend( } } + private def yarnApplicationStateCheckerThread(): Thread = { + val t = new Thread { + override def run() { + while (!stopping) { + val report = client.getApplicationReport(appId) + val state = report.getYarnApplicationState() + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.KILLED + || state == YarnApplicationState.FAILED) { + logError(s"Yarn application already ended: $state") + sc.stop() + stopping = true + } + Thread.sleep(1000L) + } + checkerThread = null + Thread.currentThread().interrupt() + } + } + t.setName("Yarn Application State Checker") + t.setDaemon(true) + t.start() + t + } + override def stop() { + stopping = true super.stop() client.stop logInfo("Stopped") diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index eaf594c8b49b9..035356d390c80 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -59,6 +59,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var yarnAllocator: YarnAllocationHandler = _ private var isFinished: Boolean = false private var uiAddress: String = _ + private var uiHistoryAddress: String = _ private val maxAppAttempts: Int = conf.getInt( YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) private var isLastAMRetry: Boolean = true @@ -216,6 +217,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, if (sparkContext != null) { uiAddress = sparkContext.ui.appUIHostPort + uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, amClient, @@ -312,8 +314,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, logInfo("Unregistering ApplicationMaster with " + status) if (registered) { - val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") - amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl) + amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 5ac95f3798723..fc7b8320d734d 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -19,15 +19,12 @@ package org.apache.spark.deploy.yarn import java.net.Socket import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ -import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -57,10 +54,16 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var yarnAllocator: YarnAllocationHandler = _ - private var driverClosed:Boolean = false + private var driverClosed: Boolean = false + private var isFinished: Boolean = false + private var registered: Boolean = false private var amClient: AMRMClient[ContainerRequest] = _ + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + val securityManager = new SecurityManager(sparkConf) val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityManager)._1 @@ -101,7 +104,12 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp amClient.start() appAttemptId = ApplicationMaster.getApplicationAttemptId() - registerApplicationMaster() + synchronized { + if (!isFinished) { + registerApplicationMaster() + registered = true + } + } waitForSparkMaster() addAmIpFilter() @@ -210,6 +218,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() Thread.sleep(100) @@ -228,12 +237,20 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + } + } + private def launchReporterThread(_sleepTime: Long): Thread = { val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime val t = new Thread { override def run() { while (!driverClosed) { + checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") yarnAllocator.allocateResources() @@ -248,10 +265,18 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp t } - def finishApplicationMaster(status: FinalApplicationStatus) { - logInfo("Unregistering ApplicationMaster with " + status) - val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") - amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl) + def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { + synchronized { + if (isFinished) { + return + } + logInfo("Unregistering ApplicationMaster with " + status) + if (registered) { + val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") + amClient.unregisterApplicationMaster(status, appMessage, trackingUrl) + } + isFinished = true + } } }