diff --git a/assembly/pom.xml b/assembly/pom.xml
index 5ec9da22ae83f..31a01e4d8e1de 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -349,5 +349,15 @@
+    <profile>
+      <id>kinesis-asl</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.httpcomponents</groupId>
+          <artifactId>httpclient</artifactId>
+          <version>${commons.httpclient.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 75ea535f2f57b..e8f761eaa5799 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -83,6 +83,15 @@ trait FutureAction[T] extends Future[T] {
*/
@throws(classOf[Exception])
def get(): T = Await.result(this, Duration.Inf)
+
+ /**
+ * Returns the job IDs run by the underlying async operation.
+ *
+ * This returns the current snapshot of the job list. Certain operations may run multiple
+ * jobs, so multiple calls to this method may return different lists.
+ */
+ def jobIds: Seq[Int]
+
}
@@ -150,8 +159,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
}
- /** Get the corresponding job id for this action. */
- def jobId = jobWaiter.jobId
+ def jobIds = Seq(jobWaiter.jobId)
}
@@ -171,6 +179,8 @@ class ComplexFutureAction[T] extends FutureAction[T] {
// is cancelled before the action was even run (and thus we have no thread to interrupt).
@volatile private var _cancelled: Boolean = false
+ @volatile private var jobs: Seq[Int] = Nil
+
// A promise used to signal the future.
private val p = promise[T]()
@@ -219,6 +229,8 @@ class ComplexFutureAction[T] extends FutureAction[T] {
}
}
+ this.jobs = jobs ++ job.jobIds
+
// Wait for the job to complete. If the action is cancelled (with an interrupt),
// cancel the job and stop the execution. This is not in a synchronized block because
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
@@ -255,4 +267,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
override def isCompleted: Boolean = p.isCompleted
override def value: Option[Try[T]] = p.future.value
+
+ def jobIds = jobs
+
}
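A minimal sketch of how the new `jobIds` accessor might be consumed from application code, for example to log which jobs an asynchronous action submitted. The object name, RDD contents, and the use of `countAsync` here are illustrative assumptions, not part of the patch.

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // brings countAsync and the other async actions into scope

object JobIdsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("jobIds-sketch"))
    val action = sc.parallelize(1 to 1000, 4).countAsync()
    // jobIds is a snapshot: a ComplexFutureAction (e.g. takeAsync) may submit more jobs
    // later, so callers should re-read it rather than cache the first result.
    println(s"jobs so far: ${action.jobIds.mkString(", ")}")
    val count = Await.result(action, Duration.Inf)
    println(s"count = $count, jobs run: ${action.jobIds.mkString(", ")}")
    sc.stop()
  }
}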
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 21d0cc7b5cbaa..6b63eb23e9ee1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -23,6 +23,7 @@ import java.io.EOFException
import scala.collection.immutable.Map
import scala.reflect.ClassTag
+import scala.collection.mutable.ListBuffer
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapred.FileSplit
@@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.{NextIterator, Utils}
+import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
/**
@@ -249,9 +251,21 @@ class HadoopRDD[K, V](
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- // TODO: Filtering out "localhost" in case of file:// URLs
- val hadoopSplit = split.asInstanceOf[HadoopPartition]
- hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+ val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value
+ val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val lsplit = c.inputSplitWithLocationInfo.cast(hsplit)
+ val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+ case e: Exception =>
+ logDebug("Failed to use InputSplitWithLocationInfo.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
}
override def checkpoint() {
@@ -261,7 +275,7 @@ class HadoopRDD[K, V](
def getConf: Configuration = getJobConf()
}
-private[spark] object HadoopRDD {
+private[spark] object HadoopRDD extends Logging {
/** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
val CONFIGURATION_INSTANTIATION_LOCK = new Object()
@@ -309,4 +323,42 @@ private[spark] object HadoopRDD {
f(inputSplit, firstParent[T].iterator(split, context))
}
}
+
+ private[spark] class SplitInfoReflections {
+ val inputSplitWithLocationInfo =
+ Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+ val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
+ val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+ val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
+ val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+ val isInMemory = splitLocationInfo.getMethod("isInMemory")
+ val getLocation = splitLocationInfo.getMethod("getLocation")
+ }
+
+ private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try {
+ Some(new SplitInfoReflections)
+ } catch {
+ case e: Exception =>
+ logDebug("SplitLocationInfo and other new Hadoop classes are " +
+ "unavailable. Using the older Hadoop location info code.", e)
+ None
+ }
+
+ private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
+ val out = ListBuffer[String]()
+ infos.foreach { loc => {
+ val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get.
+ getLocation.invoke(loc).asInstanceOf[String]
+ if (locationStr != "localhost") {
+ if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory.
+ invoke(loc).asInstanceOf[Boolean]) {
+ logDebug("Partition " + locationStr + " is cached by Hadoop.")
+ out += new HDFSCacheTaskLocation(locationStr).toString
+ } else {
+ out += new HostTaskLocation(locationStr).toString
+ }
+ }
+ }}
+ out.seq
+ }
}
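The hunk above probes for the Hadoop 2.5+ split-location API purely through reflection so the code keeps compiling and running against older Hadoop releases. Below is a condensed, self-contained sketch of that pattern; the Hadoop class and method names mirror the patch, while the surrounding object and helper are illustrative assumptions.

object OptionalHadoopLocationApi {
  // Resolve the optional class and methods once; any failure means "older Hadoop".
  private class Reflections {
    val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
    val isInMemory = splitLocationInfo.getMethod("isInMemory")
    val getLocation = splitLocationInfo.getMethod("getLocation")
  }

  private val reflections: Option[Reflections] =
    try Some(new Reflections) catch { case _: Exception => None }

  /** Hosts whose replica of a split is cached in memory, or Nil when the API is absent. */
  def inMemoryHosts(infos: Array[AnyRef]): Seq[String] = reflections match {
    case Some(r) =>
      infos.toSeq
        .filter(info => r.isInMemory.invoke(info).asInstanceOf[Boolean])
        .map(info => r.getLocation.invoke(info).asInstanceOf[String])
    case None => Nil
  }
}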
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 4c84b3f62354d..0cccdefc5ee09 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -173,9 +173,21 @@ class NewHadoopRDD[K, V](
new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
}
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val theSplit = split.asInstanceOf[NewHadoopPartition]
- theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+ override def getPreferredLocations(hsplit: Partition): Seq[String] = {
+ val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value
+ val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+ case e : Exception =>
+ logDebug("Failed to use InputSplit#getLocationInfo.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(split.getLocations.filter(_ != "localhost"))
}
def getConf: Configuration = confBroadcast.value.value
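Both getPreferredLocations overrides now share the same shape: prefer the reflection-derived locations, otherwise fall back to the plain getLocations list with "localhost" filtered out. A tiny sketch of that fallback, with made-up host names:

object PreferredLocationsFallback {
  def preferred(reflected: Option[Seq[String]], plain: Seq[String]): Seq[String] =
    reflected.getOrElse(plain.filter(_ != "localhost"))

  def main(args: Array[String]): Unit = {
    assert(preferred(None, Seq("localhost", "host1")) == Seq("host1"))
    assert(preferred(Some(Seq("hdfs_cache_host2")), Seq("localhost")) == Seq("hdfs_cache_host2"))
  }
}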
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab9e97c8fe409..2aba40d152e3e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -208,7 +208,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Get the preferred locations of a partition (as hostnames), taking into account whether the
+ * Get the preferred locations of a partition, taking into account whether the
* RDD is checkpointed.
*/
final def preferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5a96f52a10cd4..8135cdbb4c31f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1303,7 +1303,7 @@ class DAGScheduler(
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
if (!rddPrefs.isEmpty) {
- return rddPrefs.map(host => TaskLocation(host))
+ return rddPrefs.map(TaskLocation(_))
}
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
index 67c9a6760b1b3..10c685f29d3ac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -22,13 +22,51 @@ package org.apache.spark.scheduler
* In the latter case, we will prefer to launch the task on that executorID, but our next level
* of preference will be executors on the same host if this is not possible.
*/
-private[spark]
-class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable {
- override def toString: String = "TaskLocation(" + host + ", " + executorId + ")"
+private[spark] sealed trait TaskLocation {
+ def host: String
+}
+
+/**
+ * A location that includes both a host and an executor id on that host.
+ */
+private [spark] case class ExecutorCacheTaskLocation(override val host: String,
+ val executorId: String) extends TaskLocation {
+}
+
+/**
+ * A location on a host.
+ */
+private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation {
+ override def toString = host
+}
+
+/**
+ * A location on a host that is cached by HDFS.
+ */
+private [spark] case class HDFSCacheTaskLocation(override val host: String)
+ extends TaskLocation {
+ override def toString = TaskLocation.inMemoryLocationTag + host
}
private[spark] object TaskLocation {
- def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
+ // We identify hosts on which the block is cached with this prefix. Because this prefix contains
+ // underscores, which are not legal characters in hostnames, there should be no potential for
+ // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames.
+ val inMemoryLocationTag = "hdfs_cache_"
+
+ def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId)
- def apply(host: String) = new TaskLocation(host, None)
+ /**
+ * Create a TaskLocation from a string returned by getPreferredLocations.
+ * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the
+ * location is cached.
+ */
+ def apply(str: String) = {
+ val hstr = str.stripPrefix(inMemoryLocationTag)
+ if (hstr.equals(str)) {
+ new HostTaskLocation(str)
+ } else {
+ new HDFSCacheTaskLocation(hstr)
+ }
+ }
}
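The plain string is the contract between RDD.getPreferredLocations, which returns Seq[String], and TaskLocation.apply, which turns each string back into a typed location. A self-contained sketch of that round trip, using toy case classes rather than the private[spark] ones:

object TaskLocationStringSketch {
  val inMemoryLocationTag = "hdfs_cache_"

  sealed trait Loc { def host: String }
  case class Host(host: String) extends Loc { override def toString = host }
  case class HdfsCached(host: String) extends Loc {
    override def toString = inMemoryLocationTag + host
  }

  // Mirrors TaskLocation.apply(str): strip the prefix; if nothing was stripped the string
  // was a bare hostname, otherwise it named an HDFS-cached location.
  def parse(str: String): Loc = {
    val hstr = str.stripPrefix(inMemoryLocationTag)
    if (hstr == str) Host(str) else HdfsCached(hstr)
  }

  def main(args: Array[String]): Unit = {
    assert(parse("host1") == Host("host1"))
    assert(parse(HdfsCached("host2").toString) == HdfsCached("host2"))
  }
}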
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d9d53faf843ff..a6c23fc85a1b0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -181,8 +181,24 @@ private[spark] class TaskSetManager(
}
for (loc <- tasks(index).preferredLocations) {
- for (execId <- loc.executorId) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ loc match {
+ case e: ExecutorCacheTaskLocation =>
+ addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer))
+ case e: HDFSCacheTaskLocation => {
+ val exe = sched.getExecutorsAliveOnHost(loc.host)
+ exe match {
+ case Some(set) => {
+ for (e <- set) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer))
+ }
+ logInfo(s"Pending task $index has a cached location at ${e.host} " +
+ ", where there are executors " + set.mkString(","))
+ }
+ case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
+ ", but there are no executors alive there.")
+ }
+ }
+ case _ => Unit
}
addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
for (rack <- sched.getRackForHost(loc.host)) {
@@ -283,7 +299,10 @@ private[spark] class TaskSetManager(
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
for (index <- speculatableTasks if canRunOnHost(index)) {
val prefs = tasks(index).preferredLocations
- val executors = prefs.flatMap(_.executorId)
+ val executors = prefs.flatMap(_ match {
+ case e: ExecutorCacheTaskLocation => Some(e.executorId)
+ case _ => None
+ });
if (executors.contains(execId)) {
speculatableTasks -= index
return Some((index, TaskLocality.PROCESS_LOCAL))
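An illustrative sketch, with toy types rather than the scheduler internals, of the routing decision the new pattern match makes: an executor-cached preference is queued under that one executor, while an HDFS-cached host is credited to every executor alive on that host, which is what lets such tasks be offered PROCESS_LOCAL.

object PendingTaskRoutingSketch {
  sealed trait Loc { def host: String }
  case class ExecutorCache(host: String, executorId: String) extends Loc
  case class HdfsCache(host: String) extends Loc
  case class HostOnly(host: String) extends Loc

  /** Executor ids under which a task with this preference should also be queued. */
  def executorsFor(loc: Loc, aliveOnHost: String => Set[String]): Set[String] = loc match {
    case ExecutorCache(_, execId) => Set(execId)
    case HdfsCache(host) => aliveOnHost(host) // every live executor on the cached host
    case HostOnly(_) => Set.empty
  }

  def main(args: Array[String]): Unit = {
    val alive = Map("host3" -> Set("execC1", "execC2")).withDefaultValue(Set.empty[String])
    assert(executorsFor(HdfsCache("host3"), alive) == Set("execC1", "execC2"))
    assert(executorsFor(HostOnly("host1"), alive).isEmpty)
  }
}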
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e5b83c069d961..b3025c6ec3364 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1439,7 +1439,7 @@ private[spark] object Utils extends Logging {
val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
for (offset <- 0 to maxRetries) {
// Do not increment port if startPort is 0, which is treated as a special port
- val tryPort = if (startPort == 0) startPort else (startPort + offset) % (65536 - 1024) + 1024
+ val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536
try {
val (service, port) = startService(tryPort)
logInfo(s"Successfully started service$serviceString on port $port.")
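A quick, self-contained check of the new wraparound arithmetic; the port numbers are made up for illustration.

object PortRetrySketch {
  def tryPort(startPort: Int, offset: Int): Int =
    if (startPort == 0) startPort else (startPort + offset) % 65536

  def main(args: Array[String]): Unit = {
    assert(tryPort(8080, 3) == 8083)  // common case: simple increment
    assert(tryPort(65535, 2) == 1)    // wraps around instead of exceeding 65535
    assert(tryPort(0, 7) == 0)        // 0 is special: "bind to any free port"
  }
}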
diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala
new file mode 100644
index 0000000000000..db9c25fc457a4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+
+import org.apache.spark.SparkContext._
+
+class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext {
+
+ before {
+ sc = new SparkContext("local", "FutureActionSuite")
+ }
+
+ test("simple async action") {
+ val rdd = sc.parallelize(1 to 10, 2)
+ val job = rdd.countAsync()
+ val res = Await.result(job, Duration.Inf)
+ res should be (10)
+ job.jobIds.size should be (1)
+ }
+
+ test("complex async action") {
+ val rdd = sc.parallelize(1 to 15, 3)
+ val job = rdd.takeAsync(10)
+ val res = Await.result(job, Duration.Inf)
+ res should be (1 to 10)
+ job.jobIds.size should be (2)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 93e8ddacf8865..c0b07649eb6dd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -642,6 +642,28 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.resourceOffer("execC", "host3", ANY) !== None)
}
+ test("Test that locations with HDFSCacheTaskLocation are treated as PROCESS_LOCAL.") {
+ // Regression test for SPARK-2931
+ sc = new SparkContext("local", "test")
+ val sched = new FakeTaskScheduler(sc,
+ ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
+ val taskSet = FakeTask.createTaskSet(3,
+ Seq(HostTaskLocation("host1")),
+ Seq(HostTaskLocation("host2")),
+ Seq(HDFSCacheTaskLocation("host3")))
+ val clock = new FakeClock
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execA")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execB")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execC")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+ }
def createTaskResult(id: Int): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
diff --git a/examples/pom.xml b/examples/pom.xml
index 2b561857f9f33..eb49a0e5af22d 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -43,6 +43,11 @@
       <artifactId>spark-streaming-kinesis-asl_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.httpcomponents</groupId>
+      <artifactId>httpclient</artifactId>
+      <version>${commons.httpclient.version}</version>
+    </dependency>
diff --git a/pom.xml b/pom.xml
index 70cb9729ff6d3..7756c89b00cad 100644
--- a/pom.xml
+++ b/pom.xml
@@ -138,6 +138,7 @@
     <jets3t.version>0.7.1</jets3t.version>
     <aws.java.sdk.version>1.8.3</aws.java.sdk.version>
     <aws.kinesis.client.version>1.1.0</aws.kinesis.client.version>
+    <commons.httpclient.version>4.2.6</commons.httpclient.version>
     <PermGen>64m</PermGen>
     <MaxPermGen>512m</MaxPermGen>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4076ebc6fc8d5..d499302124461 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -41,6 +41,8 @@ object MimaExcludes {
MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++
MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++
Seq(
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+ "org.apache.spark.scheduler.TaskLocation"),
// Added normL1 and normL2 to trait MultivariateStatisticalSummary
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 862f78702c4e6..26336332c05a2 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -166,7 +166,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
val withProjection =
g.map {g =>
- Aggregate(assignAliases(g), assignAliases(p), withFilter)
+ Aggregate(g, assignAliases(p), withFilter)
}.getOrElse(Project(assignAliases(p), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index c7d73d3990c3a..ac043d4dd8eb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -157,8 +157,18 @@ case object StringType extends NativeType with PrimitiveType {
def simpleString: String = "string"
}
-case object BinaryType extends DataType with PrimitiveType {
+case object BinaryType extends NativeType with PrimitiveType {
private[sql] type JvmType = Array[Byte]
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val ordering = new Ordering[JvmType] {
+ def compare(x: Array[Byte], y: Array[Byte]): Int = {
+ for (i <- 0 until x.length; if i < y.length) {
+ val res = x(i).compareTo(y(i))
+ if (res != 0) return res
+ }
+ return x.length - y.length
+ }
+ }
def simpleString: String = "binary"
}
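A standalone sketch of the comparator semantics introduced above: byte-wise lexicographic order, with a shorter array sorting before any longer array that shares its prefix. The sample values match the binaryData fixture added later in this patch; the object and value names are illustrative.

object BinaryOrderingSketch {
  val byteArrayOrdering: Ordering[Array[Byte]] = new Ordering[Array[Byte]] {
    def compare(x: Array[Byte], y: Array[Byte]): Int = {
      var i = 0
      while (i < x.length && i < y.length) {
        val res = x(i).compareTo(y(i))
        if (res != 0) return res
        i += 1
      }
      x.length - y.length  // equal prefixes: the shorter array sorts first
    }
  }

  def main(args: Array[String]): Unit = {
    val sorted = Seq("22", "12", "123", "121", "122")
      .map(_.getBytes("UTF-8"))
      .sorted(byteArrayOrdering)
    assert(sorted.map(new String(_, "UTF-8")) == Seq("12", "121", "122", "123", "22"))
  }
}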
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index c2f48a902a3e9..f88099ec0761e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -37,7 +37,7 @@ trait Command {
* The `execute()` method of all the physical command classes should reference `sideEffectResult`
* so that the command can be executed eagerly right after the command query is created.
*/
- protected[sql] lazy val sideEffectResult: Seq[Row] = Seq.empty[Row]
+ protected lazy val sideEffectResult: Seq[Row] = Seq.empty[Row]
override def executeCollect(): Array[Row] = sideEffectResult.toArray
@@ -53,7 +53,7 @@ case class SetCommand(
@transient context: SQLContext)
extends LeafNode with Command with Logging {
- override protected[sql] lazy val sideEffectResult: Seq[Row] = (key, value) match {
+ override protected lazy val sideEffectResult: Seq[Row] = (key, value) match {
// Set value for key k.
case (Some(k), Some(v)) =>
if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
@@ -121,7 +121,7 @@ case class ExplainCommand(
extends LeafNode with Command {
// Run through the optimizer to generate the physical plan.
- override protected[sql] lazy val sideEffectResult: Seq[Row] = try {
+ override protected lazy val sideEffectResult: Seq[Row] = try {
// TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties.
val queryExecution = context.executePlan(logicalPlan)
val outputString = if (extended) queryExecution.toString else queryExecution.simpleString
@@ -141,7 +141,7 @@ case class ExplainCommand(
case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext)
extends LeafNode with Command {
- override protected[sql] lazy val sideEffectResult = {
+ override protected lazy val sideEffectResult = {
if (doCache) {
context.cacheTable(tableName)
} else {
@@ -161,7 +161,7 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])(
@transient context: SQLContext)
extends LeafNode with Command {
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
Row("# Registered as a temporary table", null, null) +:
child.output.map(field => Row(field.name, field.dataType.toString, null))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 08376eb5e5c4e..6fb6cb8db0c8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -190,6 +190,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ checkAnswer(
+ sql("SELECT b FROM binaryData ORDER BY a ASC"),
+ (1 to 5).map(Row(_)).toSeq)
+
+ checkAnswer(
+ sql("SELECT b FROM binaryData ORDER BY a DESC"),
+ (1 to 5).map(Row(_)).toSeq.reverse)
+
checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
arrayData.collect().sortBy(_.data(0)).toSeq)
@@ -672,4 +680,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
("true", "false") :: Nil)
}
+
+ test("SPARK-3371 Renaming a function expression with group by gives error") {
+ registerFunction("len", (s: String) => s.length)
+ checkAnswer(
+ sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index eb33a61c6e811..10b7979df7375 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -54,6 +54,16 @@ object TestData {
TestData2(3, 2) :: Nil)
testData2.registerTempTable("testData2")
+ case class BinaryData(a: Array[Byte], b: Int)
+ val binaryData: SchemaRDD =
+ TestSQLContext.sparkContext.parallelize(
+ BinaryData("12".getBytes(), 1) ::
+ BinaryData("22".getBytes(), 5) ::
+ BinaryData("122".getBytes(), 3) ::
+ BinaryData("121".getBytes(), 2) ::
+ BinaryData("123".getBytes(), 4) :: Nil)
+ binaryData.registerTempTable("binaryData")
+
// TODO: There is no way to express null primitives as case classes currently...
val testData3 =
logical.LocalRelation('a.int, 'b.int).loadData(
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index bd3f68d92d8c7..910174a153768 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -113,7 +113,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
case ByteType =>
to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal)))
case ShortType =>
- to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal)))
+ to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal)))
case TimestampType =>
to.addColumnValue(
ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp]))
@@ -145,7 +145,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
case ByteType =>
to.addColumnValue(ColumnValue.byteValue(null))
case ShortType =>
- to.addColumnValue(ColumnValue.intValue(null))
+ to.addColumnValue(ColumnValue.shortValue(null))
case TimestampType =>
to.addColumnValue(ColumnValue.timestampValue(null))
case BinaryType | _: ArrayType | _: StructType | _: MapType =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 3e1a7b71528e0..fdb56901f9ddb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -231,12 +231,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
@transient protected[hive] lazy val sessionState = {
val ss = new SessionState(hiveconf)
setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
+ SessionState.start(ss)
+ ss.err = new PrintStream(outputBuffer, true, "UTF-8")
+ ss.out = new PrintStream(outputBuffer, true, "UTF-8")
+
ss
}
- sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
- sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")
-
override def setConf(key: String, value: String): Unit = {
super.setConf(key, value)
runSqlHive(s"SET $key=$value")
@@ -273,7 +274,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
results
}
- SessionState.start(sessionState)
/**
* Execute the command using Hive and return the results as a sequence. Each element
@@ -404,7 +404,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
// be similar with Hive.
describeHiveTableCommand.hiveString
case command: PhysicalCommand =>
- command.sideEffectResult.map(_.head.toString)
+ command.executeCollect().map(_.head.toString)
case other =>
val result: Seq[Seq[Any]] = toRdd.collect().toSeq
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index fa889ec104c6e..d633c42c6bd67 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -213,6 +213,8 @@ private[hive] trait HiveInspectors {
case _: JavaHiveDecimalObjectInspector => DecimalType
case _: WritableTimestampObjectInspector => TimestampType
case _: JavaTimestampObjectInspector => TimestampType
+ case _: WritableVoidObjectInspector => NullType
+ case _: JavaVoidObjectInspector => NullType
}
implicit class typeInfoConversions(dt: DataType) {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 0aa6292c0184e..4f3f808c93dc8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -855,7 +855,7 @@ private[hive] object HiveQl {
case Token("TOK_SELEXPR",
e :: Token(alias, Nil) :: Nil) =>
- Some(Alias(nodeToExpr(e), alias)())
+ Some(Alias(nodeToExpr(e), cleanIdentifier(alias))())
/* Hints are ignored */
case Token("TOK_HINTLIST", _) => None
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 70fb15259e7d7..4a999b98ad92b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -40,8 +40,10 @@ import org.apache.spark.sql.SQLConf
/* Implicit conversions */
import scala.collection.JavaConversions._
+// SPARK-3729: Test key required to check for initialization errors with config.
object TestHive
- extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf()))
+ extends TestHiveContext(
+ new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", "")))
/**
* A locally running test instance of Spark's Hive execution engine.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
index 1017fe6d5396d..3625708d03175 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
@@ -30,23 +30,23 @@ import org.apache.spark.sql.hive.MetastoreRelation
* Create table and insert the query result into it.
* @param database the database name of the new relation
* @param tableName the table name of the new relation
- * @param insertIntoRelation function of creating the `InsertIntoHiveTable`
+ * @param insertIntoRelation function that creates the `InsertIntoHiveTable`
* by specifying the `MetaStoreRelation`, the data will be inserted into that table.
* TODO Add more table creating properties, e.g. SerDe, StorageHandler, in-memory cache etc.
*/
@Experimental
case class CreateTableAsSelect(
- database: String,
- tableName: String,
- query: SparkPlan,
- insertIntoRelation: MetastoreRelation => InsertIntoHiveTable)
- extends LeafNode with Command {
+ database: String,
+ tableName: String,
+ query: SparkPlan,
+ insertIntoRelation: MetastoreRelation => InsertIntoHiveTable)
+ extends LeafNode with Command {
def output = Seq.empty
// A lazy computing of the metastoreRelation
private[this] lazy val metastoreRelation: MetastoreRelation = {
- // Create the table
+ // Create the table
val sc = sqlContext.asInstanceOf[HiveContext]
sc.catalog.createTable(database, tableName, query.output, false)
// Get the Metastore Relation
@@ -55,7 +55,7 @@ case class CreateTableAsSelect(
}
}
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
insertIntoRelation(metastoreRelation).execute
Seq.empty[Row]
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
index 317801001c7a4..106cede9788ec 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
@@ -48,7 +48,7 @@ case class DescribeHiveTableCommand(
.mkString("\t")
}
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
// Trying to mimic the format of Hive's output. But not exactly the same.
var results: Seq[(String, String, String)] = Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala
index 8f10e1ba7f426..6930c2babd117 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala
@@ -32,7 +32,7 @@ case class NativeCommand(
@transient context: HiveContext)
extends LeafNode with Command {
- override protected[sql] lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_))
+ override protected lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_))
override def otherCopyArgs = context :: Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index d61c5e274a596..0fc674af31885 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -37,7 +37,7 @@ case class AnalyzeTable(tableName: String) extends LeafNode with Command {
def output = Seq.empty
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
hiveContext.analyze(tableName)
Seq.empty[Row]
}
@@ -53,7 +53,7 @@ case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with
def output = Seq.empty
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
val ifExistsClause = if (ifExists) "IF EXISTS " else ""
hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
hiveContext.catalog.unregisterTable(None, tableName)
@@ -70,7 +70,7 @@ case class AddJar(path: String) extends LeafNode with Command {
override def output = Seq.empty
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
hiveContext.runSqlHive(s"ADD JAR $path")
hiveContext.sparkContext.addJar(path)
Seq.empty[Row]
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 679efe082f2a0..3647bb1c4ce7d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -63,4 +63,10 @@ class SQLQuerySuite extends QueryTest {
sql("SELECT key, value FROM test_ctas_123 ORDER BY key"),
sql("SELECT key, value FROM src ORDER BY key").collect().toSeq)
}
+
+ test("SPARK-3708 Backticks aren't handled correctly in aliases") {
+ checkAnswer(
+ sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"),
+ sql("SELECT `key` FROM src").collect().toSeq)
+ }
}