diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index 3ced11a5e6c11..2e7abac1f1bdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -103,6 +103,19 @@ private[sql] trait CacheManager {
     cachedData.remove(dataIndex)
   }
 
+  /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */
+  private[sql] def tryUncacheQuery(
+      query: SchemaRDD,
+      blocking: Boolean = true): Boolean = writeLock {
+    val planToCache = query.queryExecution.analyzed
+    val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
+    val found = dataIndex >= 0
+    if (found) {
+      cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+      cachedData.remove(dataIndex)
+    }
+    found
+  }
   /** Optionally returns cached data for the given SchemaRDD */
   private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f45aec00f970b..9e61d18f7e926 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -278,6 +278,19 @@ class SQLContext(@transient val sparkContext: SparkContext)
     catalog.registerTable(None, tableName, rdd.queryExecution.logical)
   }
 
+  /**
+   * Drops the temporary table with the given table name in the catalog. If the table has been
+   * cached/persisted before, it's also unpersisted.
+   *
+   * @param tableName the name of the table to be unregistered.
+   *
+   * @group userf
+   */
+  def dropTempTable(tableName: String): Unit = {
+    tryUncacheQuery(table(tableName))
+    catalog.unregisterTable(None, tableName)
+  }
+
   /**
    * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is
    * used for SQL parsing can be configured with 'spark.sql.dialect'.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 5bb6f6c85d801..0f2dcdcacf0ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -73,16 +73,18 @@ private[sql] object JsonRDD extends Logging {
 
   def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = {
     val (topLevel, structLike) = values.partition(_.size == 1)
+
     val topLevelFields = topLevel.filter {
       name => resolved.get(prefix ++ name).get match {
         case ArrayType(elementType, _) => {
           def hasInnerStruct(t: DataType): Boolean = t match {
-            case s: StructType => false
+            case s: StructType => true
             case ArrayType(t1, _) => hasInnerStruct(t1)
-            case o => true
+            case o => false
           }
 
-          hasInnerStruct(elementType)
+          // Check if this array has inner struct.
+          !hasInnerStruct(elementType)
         }
         case struct: StructType => false
         case _ => true
@@ -90,8 +92,11 @@ private[sql] object JsonRDD extends Logging {
     }.map {
       a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true)
     }
+    val topLevelFieldNameSet = topLevelFields.map(_.name)
 
-    val structFields: Seq[StructField] = structLike.groupBy(_(0)).map {
+    val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter {
+      case (name, _) => !topLevelFieldNameSet.contains(name)
+    }.map {
       case (name, fields) => {
         val nestedFields = fields.map(_.tail)
         val structType = makeStruct(nestedFields, prefix :+ name)
@@ -354,7 +359,8 @@ private[sql] object JsonRDD extends Logging {
         case (key, value) =>
           if (count > 0) builder.append(",")
           count += 1
-          builder.append(s"""\"${key}\":${toString(value)}""")
+          val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value)
+          builder.append(s"""\"${key}\":${stringValue}""")
       }
       builder.append("}")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 44a2961b27eda..765fa82776341 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -231,4 +231,24 @@ class CachedTableSuite extends QueryTest {
       assert(cached.statistics.sizeInBytes === actualSizeInBytes)
     }
   }
+
+  test("Drops temporary table") {
+    testData.select('key).registerTempTable("t1")
+    table("t1")
+    dropTempTable("t1")
+    assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found"))
+  }
+
+  test("Drops cached temporary table") {
+    testData.select('key).registerTempTable("t1")
+    testData.select('key).registerTempTable("t2")
+    cacheTable("t1")
+
+    assert(isCached("t1"))
+    assert(isCached("t2"))
+
+    dropTempTable("t1")
+    assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found"))
+    assert(!isCached("t2"))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 99da5012349d0..b329d3df5a9dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -483,7 +483,8 @@ class JsonSuite extends QueryTest {
     val expectedSchema = StructType(
       StructField("array1", ArrayType(StringType, true), true) ::
       StructField("array2", ArrayType(StructType(
-        StructField("field", LongType, true) :: Nil), false), true) :: Nil)
+        StructField("field", LongType, true) :: Nil), false), true) ::
+      StructField("array3", ArrayType(StringType, false), true) :: Nil)
 
     assert(expectedSchema === jsonSchemaRDD.schema)
 
@@ -492,12 +493,14 @@ class JsonSuite extends QueryTest {
     checkAnswer(
       sql("select * from jsonTable"),
       Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]",
-        """{"field":str}"""), Seq(Seq(214748364700L), Seq(1))) :: Nil
+        """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) ::
+      Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) ::
+      Seq(null, null, Seq("1", "2", "3")) :: Nil
     )
 
     // Treat an element as a number.
     checkAnswer(
-      sql("select array1[0] + 1 from jsonTable"),
+      sql("select array1[0] + 1 from jsonTable where array1 is not null"),
       2
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index c204162dd2fc1..e5773a55875bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -57,7 +57,9 @@ object TestJsonData {
   val arrayElementTypeConflict =
     TestSQLContext.sparkContext.parallelize(
       """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
-          "array2": [{"field":214748364700}, {"field":1}]}""" :: Nil)
+          "array2": [{"field":214748364700}, {"field":1}]}""" ::
+      """{"array3": [{"field":"str"}, {"field":1}]}""" ::
+      """{"array3": [1, 2, 3]}""" :: Nil)
 
   val missingFields =
     TestSQLContext.sparkContext.parallelize(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 1e2bf5cc4b0b4..58815daa82276 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.hive
 
-import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._
@@ -114,6 +114,47 @@ private[hive] trait HiveInspectors {
           unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
   }
+
+  /**
+   * Wraps with Hive types based on object inspector.
+   * TODO: Consolidate all hive OI/data interface code.
+   */
+  protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
+    case _: JavaHiveVarcharObjectInspector =>
+      (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
+
+    case _: JavaHiveDecimalObjectInspector =>
+      (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying())
+
+    case soi: StandardStructObjectInspector =>
+      val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
+      (o: Any) => {
+        val struct = soi.create()
+        (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
+          (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
+        }
+        struct
+      }
+
+    case loi: ListObjectInspector =>
+      val wrapper = wrapperFor(loi.getListElementObjectInspector)
+      (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))
+
+    case moi: MapObjectInspector =>
+      // The Predef.Map is scala.collection.immutable.Map.
+      // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
+      import scala.collection.Map
+
+      val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
+      val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
+      (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
+        keyWrapper(key) -> valueWrapper(value)
+      })
+
+    case _ =>
+      identity[Any]
+  }
+
   /**
    * Converts native catalyst types to the types expected by Hive
    * @param a the value to be wrapped
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 92bc1c6625892..74b4e7aaa47a5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import scala.collection.JavaConversions._
 
-import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.common.`type`.HiveVarchar
 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hadoop.hive.metastore.MetaStoreUtils
@@ -52,7 +52,7 @@ case class InsertIntoHiveTable(
     child: SparkPlan,
     overwrite: Boolean)
     (@transient sc: HiveContext)
-  extends UnaryNode with Command {
+  extends UnaryNode with Command with HiveInspectors {
 
   @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass
   @transient private lazy val hiveContext = new Context(sc.hiveconf)
@@ -68,46 +68,6 @@ case class InsertIntoHiveTable(
 
   def output = child.output
 
-  /**
-   * Wraps with Hive types based on object inspector.
-   * TODO: Consolidate all hive OI/data interface code.
-   */
-  protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
-    case _: JavaHiveVarcharObjectInspector =>
-      (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
-
-    case _: JavaHiveDecimalObjectInspector =>
-      (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying())
-
-    case soi: StandardStructObjectInspector =>
-      val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
-      (o: Any) => {
-        val struct = soi.create()
-        (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
-          (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
-        }
-        struct
-      }
-
-    case loi: ListObjectInspector =>
-      val wrapper = wrapperFor(loi.getListElementObjectInspector)
-      (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))
-
-    case moi: MapObjectInspector =>
-      // The Predef.Map is scala.collection.immutable.Map.
-      // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
-      import scala.collection.Map
-
-      val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
-      val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
-      (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
-        keyWrapper(key) -> valueWrapper(value)
-      })
-
-    case _ =>
-      identity[Any]
-  }
-
   def saveAsHiveFile(
       rdd: RDD[Row],
       valueClass: Class[_],
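
Usage sketch for the new SQLContext.dropTempTable API (not part of the patch). The SparkContext `sc`, the JSON path, and the table name are illustrative; the behaviour follows the new CachedTableSuite tests: dropping a cached temp table unpersists its data before unregistering it.

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)                      // `sc` is an existing SparkContext (assumed)
    val people = sqlContext.jsonFile("path/to/people.json")  // illustrative input path
    people.registerTempTable("people")

    sqlContext.cacheTable("people")     // caches the table's underlying plan
    sqlContext.dropTempTable("people")  // unpersists the cached data, then unregisters the table
    // A subsequent sqlContext.table("people") now fails with a RuntimeException whose message
    // starts with "Table Not Found", as asserted in the tests above.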
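
Illustrative sketch of the JsonRDD behaviour exercised by the new "array3" test data (not part of the patch; `sqlContext` is assumed). When the same array field mixes struct and primitive elements across records, the element type falls back to StringType and struct elements are re-serialized as valid JSON strings, with string values quoted ({"field":"str"} rather than {"field":str}).

    // Two records whose "array3" elements disagree on type: structs vs. numbers.
    val conflicting = sqlContext.sparkContext.parallelize(
      """{"array3": [{"field":"str"}, {"field":1}]}""" ::
      """{"array3": [1, 2, 3]}""" :: Nil)

    val jsonSchemaRDD = sqlContext.jsonRDD(conflicting)
    jsonSchemaRDD.registerTempTable("conflictTable")

    // Per the updated JsonSuite expectation, "array3" is inferred as
    // ArrayType(StringType, containsNull = false), so the rows come back as
    // Seq("""{"field":"str"}""", """{"field":1}""") and Seq("1", "2", "3").
    sqlContext.sql("select array3 from conflictTable").collect().foreach(println)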
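
Conceptual sketch of why wrapperFor moves into HiveInspectors (not part of the patch): any Hive-side operator in org.apache.spark.sql.hive can now mix in the trait and reuse the Catalyst-to-Hive conversion instead of carrying a private copy, which is what InsertIntoHiveTable does after this change. HiveValueWriter and toHiveObjects below are hypothetical names used only for illustration.

    package org.apache.spark.sql.hive

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
    import org.apache.spark.sql.catalyst.expressions.Row

    // Hypothetical helper: builds one conversion function per column up front
    // (via HiveInspectors.wrapperFor) and applies them to every row.
    private[hive] class HiveValueWriter(fieldOIs: Seq[ObjectInspector]) extends HiveInspectors {
      private val wrappers: Seq[Any => Any] = fieldOIs.map(wrapperFor)

      def toHiveObjects(row: Row): Array[Any] =
        row.zip(wrappers).map { case (value, wrapper) => wrapper(value) }.toArray
    }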