Merge remote-tracking branch 'origin/master' into udts
marmbrus committed Nov 3, 2014
2 parents 15c10a6 + 9081b9f commit e369b91
Showing 8 changed files with 114 additions and 52 deletions.
13 changes: 13 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -103,6 +103,19 @@ private[sql] trait CacheManager {
cachedData.remove(dataIndex)
}

/** Tries to remove the data for the given SchemaRDD from the cache if it's cached */
private[sql] def tryUncacheQuery(
query: SchemaRDD,
blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
val found = dataIndex >= 0
if (found) {
cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
cachedData.remove(dataIndex)
}
found
}

/** Optionally returns cached data for the given SchemaRDD */
private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
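For orientation, a minimal sketch of how the new tryUncacheQuery is meant to be used. It is private[sql], so the snippet below is a hypothetical caller placed inside the org.apache.spark.sql package; the context setup and table name are illustrative, not part of the commit. Unlike uncacheQuery, which assumes the plan is cached, tryUncacheQuery reports whether anything was actually removed.

  // Hypothetical illustration; tryUncacheQuery is private[sql], so this file must live in this package.
  package org.apache.spark.sql

  import org.apache.spark.{SparkConf, SparkContext}

  object TryUncacheSketch {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(new SparkConf().setAppName("try-uncache-sketch").setMaster("local"))
      val sqlContext = new SQLContext(sc)
      import sqlContext._

      jsonRDD(sc.parallelize("""{"key": 1}""" :: Nil)).registerTempTable("people")
      cacheTable("people")

      // Unpersists the cached columnar buffers and returns true if the plan was cached;
      // a second call returns false instead of throwing.
      println(sqlContext.tryUncacheQuery(table("people")))   // true
      println(sqlContext.tryUncacheQuery(table("people")))   // false

      sc.stop()
    }
  }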
13 changes: 13 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -278,6 +278,19 @@ class SQLContext(@transient val sparkContext: SparkContext)
catalog.registerTable(None, tableName, rdd.queryExecution.logical)
}

/**
* Drops the temporary table with the given table name in the catalog. If the table has been
* cached/persisted before, it's also unpersisted.
*
* @param tableName the name of the table to be unregistered.
*
* @group userf
*/
def dropTempTable(tableName: String): Unit = {
tryUncacheQuery(table(tableName))
catalog.unregisterTable(None, tableName)
}

/**
* Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
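A short, hypothetical usage sketch of the new public dropTempTable API (it complements registerTempTable; the names and data below are illustrative). Per the doc comment above, dropping a temp table also unpersists it if it was cached. Assuming a SQLContext built as in the sketch above:

  import sqlContext._

  jsonRDD(sc.parallelize("""{"name": "a"}""" :: Nil)).registerTempTable("temp_people")
  cacheTable("temp_people")

  // Unpersists the cached data (if any) and removes the table from the catalog.
  dropTempTable("temp_people")

  // The table is now unknown to the catalog, so this would fail with "Table Not Found: temp_people".
  // sql("SELECT * FROM temp_people")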
16 changes: 11 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -73,25 +73,30 @@ private[sql] object JsonRDD extends Logging {

def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = {
val (topLevel, structLike) = values.partition(_.size == 1)

val topLevelFields = topLevel.filter {
name => resolved.get(prefix ++ name).get match {
case ArrayType(elementType, _) => {
def hasInnerStruct(t: DataType): Boolean = t match {
case s: StructType => false
case s: StructType => true
case ArrayType(t1, _) => hasInnerStruct(t1)
case o => true
case o => false
}

hasInnerStruct(elementType)
// Check if this array has inner struct.
!hasInnerStruct(elementType)
}
case struct: StructType => false
case _ => true
}
}.map {
a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true)
}
val topLevelFieldNameSet = topLevelFields.map(_.name)

val structFields: Seq[StructField] = structLike.groupBy(_(0)).map {
val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter {
case (name, _) => !topLevelFieldNameSet.contains(name)
}.map {
case (name, fields) => {
val nestedFields = fields.map(_.tail)
val structType = makeStruct(nestedFields, prefix :+ name)
@@ -354,7 +359,8 @@ private[sql] object JsonRDD extends Logging {
case (key, value) =>
if (count > 0) builder.append(",")
count += 1
builder.append(s"""\"${key}\":${toString(value)}""")
val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value)
builder.append(s"""\"${key}\":${stringValue}""")
}
builder.append("}")

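The JsonRDD changes above invert the hasInnerStruct check (arrays whose element type contains no struct are now kept as top-level fields), drop struct-like keys that already resolved to a top-level field, and quote string values when a struct is rendered back to JSON text. A hedged sketch of the resulting behavior, reusing the hypothetical sqlContext from the earlier sketches and mirroring the new arrayElementTypeConflict test data:

  import sqlContext._

  // Element types conflict across records, so array3 resolves to an array of strings.
  val conflicting = sc.parallelize(
    """{"array3": [{"field":"str"}, {"field":1}]}""" ::
    """{"array3": [1, 2, 3]}""" :: Nil)
  val jsonTable = jsonRDD(conflicting)
  jsonTable.printSchema()              // array3: ArrayType(StringType, containsNull = false)
  jsonTable.registerTempTable("conflict")

  // Struct elements come back as JSON text with the string value quoted (the toString fix above):
  sql("SELECT array3[0] FROM conflict").collect()
  // => {"field":"str"} for the first record and 1 for the second, both as strings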
@@ -231,4 +231,24 @@ class CachedTableSuite extends QueryTest {
assert(cached.statistics.sizeInBytes === actualSizeInBytes)
}
}

test("Drops temporary table") {
testData.select('key).registerTempTable("t1")
table("t1")
dropTempTable("t1")
assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found"))
}

test("Drops cached temporary table") {
testData.select('key).registerTempTable("t1")
testData.select('key).registerTempTable("t2")
cacheTable("t1")

assert(isCached("t1"))
assert(isCached("t2"))

dropTempTable("t1")
assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found"))
assert(!isCached("t2"))
}
}
@@ -483,7 +483,8 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
StructField("array2", ArrayType(StructType(
StructField("field", LongType, true) :: Nil), false), true) :: Nil)
StructField("field", LongType, true) :: Nil), false), true) ::
StructField("array3", ArrayType(StringType, false), true) :: Nil)

assert(expectedSchema === jsonSchemaRDD.schema)

@@ -492,12 +493,14 @@
checkAnswer(
sql("select * from jsonTable"),
Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]",
"""{"field":str}"""), Seq(Seq(214748364700L), Seq(1))) :: Nil
"""{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) ::
Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) ::
Seq(null, null, Seq("1", "2", "3")) :: Nil
)

// Treat an element as a number.
checkAnswer(
sql("select array1[0] + 1 from jsonTable"),
sql("select array1[0] + 1 from jsonTable where array1 is not null"),
2
)
}
@@ -57,7 +57,9 @@ object TestJsonData {
val arrayElementTypeConflict =
TestSQLContext.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" :: Nil)
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)

val missingFields =
TestSQLContext.sparkContext.parallelize(
@@ -17,7 +17,7 @@

package org.apache.spark.sql.hive

import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._
@@ -114,6 +114,51 @@ private[hive] trait HiveInspectors {
unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
}


/**
* Wraps with Hive types based on object inspector.
* TODO: Consolidate all hive OI/data interface code.
*/
protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
case _: JavaHiveVarcharObjectInspector =>
(o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)

case _: JavaHiveDecimalObjectInspector =>
(o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying())

case soi: StandardStructObjectInspector =>
val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
(o: Any) => {
val struct = soi.create()
(soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
(field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
}
struct
}

case loi: ListObjectInspector =>
val wrapper = wrapperFor(loi.getListElementObjectInspector)
(o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))

case moi: MapObjectInspector =>
// The Predef.Map is scala.collection.immutable.Map.
// Since the map values can be mutable, we explicitly import scala.collection.Map at here.
import scala.collection.Map

val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
(o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
keyWrapper(key) -> valueWrapper(value)
})

case _ =>
identity[Any]
}

/**
* Converts native catalyst types to the types expected by Hive
* @param a the value to be wrapped
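The wrapperFor block above is lifted verbatim out of InsertIntoHiveTable (see the deletion below) into the shared HiveInspectors trait so other writers can reuse it. A minimal sketch of what the returned wrapper does, not part of the patch; it assumes a caller inside the org.apache.spark.sql.hive package (the trait is private[hive]) and uses standard Hive object inspectors:

  package org.apache.spark.sql.hive

  import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
  import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

  object WrapperForSketch extends HiveInspectors {
    def main(args: Array[String]): Unit = {
      // Inspector for list<string>: the element wrapper is `identity`, so wrapperFor only has to
      // convert the Catalyst Seq into the java.util.List that Hive's serializers expect.
      val listOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector)

      val wrap: Any => Any = wrapperFor(listOI)
      println(wrap(Seq("a", "b", "c")))   // a java.util.List: [a, b, c]
    }
  }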
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution

import scala.collection.JavaConversions._

import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.common.`type`.HiveVarchar
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.metastore.MetaStoreUtils
@@ -52,7 +52,7 @@ case class InsertIntoHiveTable(
child: SparkPlan,
overwrite: Boolean)
(@transient sc: HiveContext)
extends UnaryNode with Command {
extends UnaryNode with Command with HiveInspectors {

@transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass
@transient private lazy val hiveContext = new Context(sc.hiveconf)
@@ -68,46 +68,6 @@

def output = child.output

/**
* Wraps with Hive types based on object inspector.
* TODO: Consolidate all hive OI/data interface code.
*/
protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
case _: JavaHiveVarcharObjectInspector =>
(o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)

case _: JavaHiveDecimalObjectInspector =>
(o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying())

case soi: StandardStructObjectInspector =>
val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
(o: Any) => {
val struct = soi.create()
(soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
(field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
}
struct
}

case loi: ListObjectInspector =>
val wrapper = wrapperFor(loi.getListElementObjectInspector)
(o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))

case moi: MapObjectInspector =>
// The Predef.Map is scala.collection.immutable.Map.
// Since the map values can be mutable, we explicitly import scala.collection.Map at here.
import scala.collection.Map

val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
(o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
keyWrapper(key) -> valueWrapper(value)
})

case _ =>
identity[Any]
}

def saveAsHiveFile(
rdd: RDD[Row],
valueClass: Class[_],