From d4afeed86f2094df8cb6ba509f5a2da22c3bf02b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 22:19:06 +0800 Subject: [PATCH] Addresses comments - Migrates to the new DataFrame reader/writer API - Merges HadoopTypeConverter into HiveInspectors - Refactors test suites --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 18 +- .../spark/sql/hive/HiveInspectors.scala | 40 ++- .../sql/hive/orc/HadoopTypeConverter.scala | 61 ---- .../spark/sql/hive/orc/OrcRelation.scala | 24 +- .../apache/spark/sql/hive/orc/package.scala | 74 ----- .../spark/sql/hive/orc/NewOrcQuerySuite.scala | 9 +- .../hive/orc/OrcHadoopFsRelationSuite.scala | 4 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 32 +-- .../spark/sql/hive/orc/OrcQuerySuite.scala | 260 ++++++++---------- .../spark/sql/hive/orc/OrcSourceSuite.scala | 174 ++++-------- 11 files changed, 249 insertions(+), 449 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b1fc18ac3cb54..9f42f0f1f4398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -55,7 +55,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter = { - saveMode.toLowerCase match { + this.mode = saveMode.toLowerCase match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 37a569db311ea..a13ab74852ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -188,18 +188,20 @@ private[sql] class DDLParser( private[sql] object ResolvedDataSource { private val builtinSources = Map( - "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource], - "json" -> classOf[org.apache.spark.sql.json.DefaultSource], - "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource] + "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", + "json" -> "org.apache.spark.sql.json.DefaultSource", + "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", + "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" ) /** Given a provider name, look up the data source class definition. 
*/ def lookupDataSource(provider: String): Class[_] = { + val loader = Utils.getContextOrSparkClassLoader + if (builtinSources.contains(provider)) { - return builtinSources(provider) + return loader.loadClass(builtinSources(provider)) } - val loader = Utils.getContextOrSparkClassLoader try { loader.loadClass(provider) } catch { @@ -208,7 +210,11 @@ private[sql] object ResolvedDataSource { loader.loadClass(provider + ".DefaultSource") } catch { case cnf: java.lang.ClassNotFoundException => - sys.error(s"Failed to load class for data source: $provider") + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + sys.error("The ORC data source must be used with Hive support enabled.") + } else { + sys.error(s"Failed to load class for data source: $provider") + } } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 7c7666f6e4b7c..0a694c70e4e5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -122,7 +122,7 @@ import scala.collection.JavaConversions._ * even a normal java object (POJO) * UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet) * - * 3) ConstantObjectInspector: + * 3) ConstantObjectInspector: * Constant object inspector can be either primitive type or Complex type, and it bundles a * constant value as its property, usually the value is created when the constant object inspector * constructed. 
@@ -133,7 +133,7 @@ import scala.collection.JavaConversions._ } }}} * Hive provides 3 built-in constant object inspectors: - * Primitive Object Inspectors: + * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector * WritableConstantHiveDecimalObjectInspector @@ -147,9 +147,9 @@ import scala.collection.JavaConversions._ * WritableConstantByteObjectInspector * WritableConstantBinaryObjectInspector * WritableConstantDateObjectInspector - * Map Object Inspector: + * Map Object Inspector: * StandardConstantMapObjectInspector - * List Object Inspector: + * List Object Inspector: * StandardConstantListObjectInspector]] * Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct * Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union @@ -250,9 +250,9 @@ private[hive] trait HiveInspectors { poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => poi.getWritableConstantValue.getTimestamp.clone() - case poi: WritableConstantIntObjectInspector => + case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() - case poi: WritableConstantDoubleObjectInspector => + case poi: WritableConstantDoubleObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantBooleanObjectInspector => poi.getWritableConstantValue.get() @@ -306,7 +306,7 @@ private[hive] trait HiveInspectors { // In order to keep backward-compatible, we have to copy the // bytes with old apis val bw = x.getPrimitiveWritableObject(data) - val result = new Array[Byte](bw.getLength()) + val result = new Array[Byte](bw.getLength()) System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => @@ -394,6 +394,30 @@ private[hive] trait HiveInspectors { identity[Any] } + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. 
+ */ + def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit = + field.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) + } + /** * Converts native catalyst types to the types expected by Hive * @param a the value to be wrapped diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala deleted file mode 100644 index b5b5e56079cc3..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.orc - -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ - -import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.hive.HiveInspectors - -/** - * We can consolidate TableReader.unwrappers and HiveInspectors.wrapperFor to use - * this class. - * - */ -private[hive] object HadoopTypeConverter extends HiveInspectors { - /** - * Builds specific unwrappers ahead of time according to object inspector - * types to avoid pattern matching and branching costs per row. 
- */ - def unwrappers(fieldRefs: Seq[StructField]): Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { - _.getFieldObjectInspector match { - case oi: BooleanObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) - case oi: ByteObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) - case oi: ShortObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) - case oi: IntObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) - case oi: LongObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) - case oi: FloatObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) - case oi: DoubleObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) - case oi => - (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) - } - } - - /** - * Wraps with Hive types based on object inspector. - */ - def wrappers(oi: ObjectInspector): Any => Any = wrapperFor(oi) -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 3e3c8a9e619d5..9708199f07349 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -34,7 +34,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} @@ -50,6 +50,10 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { schema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { + assert( + sqlContext.isInstanceOf[HiveContext], + "The ORC data source can only be used with HiveContext.") + val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition])) OrcRelation(paths, parameters, schema, partitionSpec)(sqlContext) } @@ -59,7 +63,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil { + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -89,7 +93,7 @@ private[orc] class OrcOutputWriter( // Used to convert Catalyst values into Hadoop `Writable`s. private val wrappers = structOI.getAllStructFieldRefs.map { ref => - HadoopTypeConverter.wrappers(ref.getFieldObjectInspector) + wrapperFor(ref.getFieldObjectInspector) }.toArray // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. 
We use this @@ -190,7 +194,10 @@ private[orc] case class OrcTableScan( attributes: Seq[Attribute], @transient relation: OrcRelation, filters: Array[Filter], - inputPaths: Array[String]) extends Logging { + inputPaths: Array[String]) + extends Logging + with HiveInspectors { + @transient private val sqlContext = relation.sqlContext private def addColumnIds( @@ -215,7 +222,7 @@ private[orc] case class OrcTableScan( case (attr, ordinal) => soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal }.unzip - val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) + val unwrappers = fieldRefs.map(unwrapperFor) // Map each tuple to a row object iterator.map { value => val raw = deserializer.deserialize(value) @@ -240,7 +247,7 @@ private[orc] case class OrcTableScan( // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { OrcFilters.createFilter(filters).foreach { f => - conf.set(SARG_PUSHDOWN, f.toKryo) + conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } } @@ -276,3 +283,8 @@ private[orc] case class OrcTableScan( } } } + +private[orc] object OrcTableScan { + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. + private[orc] val SARG_PUSHDOWN = "sarg.pushdown" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala deleted file mode 100644 index ad0f65442b914..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{DataFrame, SaveMode} - -package object orc { - /** - * ::Experimental:: - * - * Extra ORC file loading functionality on [[HiveContext]] through implicit conversion. - * - * @since 1.4.0 - */ - @Experimental - implicit class OrcContext(sqlContext: HiveContext) { - /** - * ::Experimental:: - * - * Loads specified Parquet files, returning the result as a [[DataFrame]]. - * - * @since 1.4.0 - */ - @Experimental - @scala.annotation.varargs - def orcFile(paths: String*): DataFrame = { - val orcRelation = OrcRelation(paths.toArray, Map.empty)(sqlContext) - sqlContext.baseRelationToDataFrame(orcRelation) - } - } - - /** - * ::Experimental:: - * - * Extra ORC file writing functionality on [[DataFrame]] through implicit conversion - * - * @since 1.4.0 - */ - @Experimental - implicit class OrcDataFrame(dataFrame: DataFrame) { - /** - * ::Experimental:: - * - * Saves the contents of this [[DataFrame]] as an ORC file, preserving the schema. 
Files that - * are written out using this method can be read back in as a [[DataFrame]] using - * [[OrcContext.orcFile()]]. - * - * @since 1.4.0 - */ - @Experimental - def saveAsOrcFile(path: String, mode: SaveMode = SaveMode.Overwrite): Unit = { - dataFrame.save(path, source = classOf[DefaultSource].getCanonicalName, mode) - } - } - - // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. - private[orc] val SARG_PUSHDOWN = "sarg.pushdown" -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala index 7e326de1335e0..ad2fad05188de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala @@ -41,7 +41,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().saveAsOrcFile(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -53,8 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - import org.apache.spark.sql.hive.orc.OrcContext - withOrcFile(data)(path => f(hiveContext.orcFile(path))) + withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) } /** @@ -73,12 +72,12 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) } protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 90812b03fd2e6..080af5bb23c16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -40,7 +40,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") - .saveAsOrcFile(partitionDir.toString) + .write + .format("orc") + .save(partitionDir.toString) } val dataSchemaWithPartition = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 55d8b8c71d9ef..88c99e35260d9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -48,13 +48,13 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().saveAsOrcFile(path.getCanonicalPath) + 
data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath) } def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.saveAsOrcFile(path.getCanonicalPath) + df.write.format("orc").mode("overwrite").save(path.getCanonicalPath) } protected def withTempTable(tableName: String)(f: => Unit): Unit = { @@ -89,7 +89,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -136,7 +136,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -185,13 +185,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val orcRelation = load( - "org.apache.spark.sql.hive.orc.DefaultSource", - Map( - "path" -> base.getCanonicalPath, - ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) - - orcRelation.registerTempTable("t") + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") withTempTable("t") { checkAnswer( @@ -230,13 +228,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val orcRelation = load( - "org.apache.spark.sql.hive.orc.DefaultSource", - Map( - "path" -> base.getCanonicalPath, - ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) - - orcRelation.registerTempTable("t") + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") withTempTable("t") { checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index abc4c92d91da8..338ed7add1995 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -21,43 +21,13 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} - -case class TestRDDEntry(key: Int, value: String) - -case class NullReflectData( - intField: java.lang.Integer, - longField: java.lang.Long, - floatField: java.lang.Float, - doubleField: java.lang.Double, - booleanField: java.lang.Boolean) - -case class OptionalReflectData( - intField: Option[Int], - longField: Option[Long], - floatField: Option[Float], - doubleField: Option[Double], - booleanField: Option[Boolean]) - -case class Nested(i: 
Int, s: String) - -case class Data(array: Seq[Int], nested: Nested) - -case class AllDataTypes( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -72,7 +42,7 @@ case class AllDataTypesWithNonPrimitiveType( arrayContainsNull: Seq[Option[Int]], map: Map[Int, Long], mapValueContainsNull: Map[Int, Option[Long]], - data: Data) + data: (Seq[Int], (Int, String))) case class BinaryData(binaryData: Array[Byte]) @@ -80,7 +50,10 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { +class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest { + override val sqlContext = TestHive + + import TestHive.read def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) @@ -88,157 +61,146 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { tempFile } - test("Read/Write All Types") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range) - .map(x => - AllDataTypes(s"$x", x, x.toLong, x.toFloat,x.toDouble, x.toShort, x.toByte, x % 2 == 0)) - data.toDF().saveAsOrcFile(tempDir) + test("Read/write All Types") { + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) + } + + withOrcFile(data) { file => checkAnswer( - TestHive.orcFile(tempDir), - data.toDF().collect().toSeq) - Utils.deleteRecursively(new File(tempDir)) + read.format("orc").load(file), + data.toDF().collect()) } + } - test("read/write binary data") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil) - .toDF().saveAsOrcFile(tempDir) - TestHive.orcFile(tempDir) - .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) - .collect().toSeq == Seq("test") - Utils.deleteRecursively(new File(tempDir)) + test("Read/write binary data") { + withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => + val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0) + assert(new String(bytes, "utf8") === "test") } + } - test("Read/Write All Types with non-primitive type") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val range = 0 to 255 - val data = sparkContext.parallelize(range).map { x => - AllDataTypesWithNonPrimitiveType( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - 0 until x, - (0 until x).map(Option(_).filter(_ % 3 == 0)), - (0 until x).map(i => i -> i.toLong).toMap, - (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), - Data(0 until x, Nested(x, s"$x"))) - } - data.toDF().saveAsOrcFile(tempDir) + test("Read/write all types with non-primitive type") { + val data = (0 to 255).map { i => + AllDataTypesWithNonPrimitiveType( + s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + 0 until i, + (0 until i).map(Option(_).filter(_ % 3 == 0)), + (0 until i).map(i => i -> i.toLong).toMap, + (0 until i).map(i => i -> Option(i.toLong)).toMap + (i -> None), + (0 until i, (i, s"$i"))) + } + withOrcFile(data) { file => checkAnswer( - TestHive.orcFile(tempDir), - data.toDF().collect().toSeq) - 
Utils.deleteRecursively(new File(tempDir)) + read.format("orc").load(file), + data.toDF().collect()) } + } - test("Creating case class RDD table") { - sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - .toDF().registerTempTable("tmp") - val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) - var counter = 1 - rdd.foreach { - // '===' does not like string comparison? - row: Row => { - assert(row.getString(1).equals(s"val_$counter"), - s"row $counter value ${row.getString(1)} does not match val_$counter") - counter = counter + 1 - } - } + test("Creating case class RDD table") { + val data = (1 to 100).map(i => (i, s"val_$i")) + sparkContext.parallelize(data).toDF().registerTempTable("t") + withTempTable("t") { + checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) } + } - test("Simple selection form orc table") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val data = sparkContext.parallelize((1 to 10)) - .map(i => Person(s"name_$i", i, (0 until 2).map{ m=> - Contact(s"contact_$m", s"phone_$m") })) - data.toDF().saveAsOrcFile(tempDir) - val f = TestHive.orcFile(tempDir) - f.registerTempTable("tmp") + test("Simple selection form ORC table") { + val data = (1 to 10).map { i => + Person(s"name_$i", i, (0 to 1).map { m => Contact(s"contact_$m", s"phone_$m") }) + } + withOrcTable(data, "t") { // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = leaf-0 - var rdd = sql("SELECT name FROM tmp where age <= 5") - assert(rdd.count() == 5) + assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5) // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = (not leaf-0) - rdd = sql("SELECT name, contacts FROM tmp where age > 5") - assert(rdd.count() == 5) - var contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) - assert(contacts.count() == 10) + assertResult(10) { + sql("SELECT name, contacts FROM t where age > 5") + .flatMap(_.getAs[Seq[_]]("contacts")) + .count() + } // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // leaf-1 = (LESS_THAN age 8) // expr = (and (not leaf-0) leaf-1) - rdd = sql("SELECT name, contacts FROM tmp where age > 5 and age < 8") - assert(rdd.count() == 2) - contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) - assert(contacts.count() == 4) + { + val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") + assert(df.count() === 2) + assertResult(4) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } // ppd: // leaf-0 = (LESS_THAN age 2) // leaf-1 = (LESS_THAN_EQUALS age 8) // expr = (or leaf-0 (not leaf-1)) - rdd = sql("SELECT name, contacts FROM tmp where age < 2 or age > 8") - assert(rdd.count() == 3) - contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) - assert(contacts.count() == 6) + { + val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") + assert(df.count() === 3) + assertResult(6) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + } + } + test("save and load case class RDD with `None`s as orc") { + val data = ( + None: Option[Int], + None: Option[Long], + None: Option[Float], + None: Option[Double], + None: Option[Boolean] + ) :: Nil - Utils.deleteRecursively(new File(tempDir)) + withOrcFile(data) { file => + checkAnswer( + read.format("orc").load(file), + Row(Seq.fill(5)(null): _*)) } + } - test("save and load case class RDD with Nones as orc") { - val data = OptionalReflectData(None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - val tempDir = getTempFilePath("orcTest").getCanonicalPath - rdd.toDF().saveAsOrcFile(tempDir) - val 
readFile = TestHive.orcFile(tempDir) - val rdd_saved = readFile.collect() - assert(rdd_saved(0).toSeq === Seq.fill(5)(null)) - Utils.deleteRecursively(new File(tempDir)) + // We only support zlib in Hive 0.12.0 now + test("Default compression options for writing to an ORC file") { + withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file => + assertResult(CompressionKind.ZLIB) { + OrcFileOperator.getFileReader(file).getCompression + } } + } + + // Following codec is supported in hive-0.13.1, ignore it now + ignore("Other compression options for writing to an ORC file - 0.13.1 and above") { + val data = (1 to 100).map(i => (i, s"val_$i")) + val conf = sparkContext.hadoopConfiguration - // We only support zlib in hive0.12.0 now - test("Default Compression options for writing to an Orcfile") { - // TODO: support other compress codec - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize(1 to 100) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.toDF().saveAsOrcFile(tempDir) - val actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.ZLIB) - Utils.deleteRecursively(new File(tempDir)) + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") + withOrcFile(data) { file => + assertResult(CompressionKind.SNAPPY) { + OrcFileOperator.getFileReader(file).getCompression + } } - // Following codec is supported in hive-0.13.1, ignore it now - ignore("Other Compression options for writing to an Orcfile - 0.13.1 and above") { - val conf = TestHive.sparkContext.hadoopConfiguration - conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "SNAPPY") - var tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize(1 to 100) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.toDF().saveAsOrcFile(tempDir) - var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.SNAPPY) - Utils.deleteRecursively(new File(tempDir)) - - conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "NONE") - tempDir = getTempFilePath("orcTest").getCanonicalPath - rdd.toDF().saveAsOrcFile(tempDir) - actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.NONE) - Utils.deleteRecursively(new File(tempDir)) - - conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "LZO") - tempDir = getTempFilePath("orcTest").getCanonicalPath - rdd.toDF().saveAsOrcFile(tempDir) - actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.LZO) - Utils.deleteRecursively(new File(tempDir)) + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE") + withOrcFile(data) { file => + assertResult(CompressionKind.NONE) { + OrcFileOperator.getFileReader(file).getCompression + } } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO") + withOrcFile(data) { file => + assertResult(CompressionKind.LZO) { + OrcFileOperator.getFileReader(file).getCompression + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index f86750bcfb6d4..82e08caf46457 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.hive.orc import java.io.File + import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{Row, 
QueryTest} import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{QueryTest, Row} case class OrcData(intField: Int, stringField: String) @@ -42,25 +43,25 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { orcTableDir.mkdir() import org.apache.spark.sql.hive.test.TestHive.implicits._ - (sparkContext + sparkContext .makeRDD(1 to 10) - .map(i => OrcData(i, s"part-$i"))) - .toDF.registerTempTable(s"orc_temp_table") - - sql(s""" - create external table normal_orc - ( - intField INT, - stringField STRING - ) - STORED AS orc - location '${orcTableDir.getCanonicalPath}' - """) + .map(i => OrcData(i, s"part-$i")) + .toDF() + .registerTempTable(s"orc_temp_table") sql( - s"""insert into table normal_orc - select intField, stringField from orc_temp_table""") + s"""CREATE EXTERNAL TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |STORED AS ORC + |LOCATION '${orcTableAsDir.getCanonicalPath}' + """.stripMargin) + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) } override def afterAll(): Unit = { @@ -73,41 +74,15 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_source"), - Row(1, "part-1") :: - Row(2, "part-2") :: - Row(3, "part-3") :: - Row(4, "part-4") :: - Row(5, "part-5") :: - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source where intField > 5"), - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), - Row(1, "part-1") :: - Row(1, "part-2") :: - Row(1, "part-3") :: - Row(1, "part-4") :: - Row(1, "part-5") :: - Row(1, "part-6") :: - Row(1, "part-7") :: - Row(1, "part-8") :: - Row(1, "part-9") :: - Row(1, "part-10") :: Nil - ) - + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) } test("create temporary orc table as") { @@ -115,76 +90,36 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_source"), - Row(1, "part-1") :: - Row(2, "part-2") :: - Row(3, "part-3") :: - Row(4, "part-4") :: - Row(5, "part-5") :: - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT * FROM normal_orc_source where intField > 5"), - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + sql("SELECT * FROM normal_orc_source WHERE intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), - Row(1, "part-1") :: - Row(1, "part-2") :: - Row(1, "part-3") :: - Row(1, "part-4") :: - Row(1, "part-5") :: - Row(1, "part-6") :: - Row(1, "part-7") :: - Row(1, "part-8") :: - Row(1, "part-9") :: - Row(1, "part-10") :: Nil - ) - + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) } test("appending insert") { - sql("insert into table normal_orc_source select * from orc_temp_table 
where intField > 5") + sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") + checkAnswer( - sql("select * from normal_orc_source"), - Row(1, "part-1") :: - Row(2, "part-2") :: - Row(3, "part-3") :: - Row(4, "part-4") :: - Row(5, "part-5") :: - Row(6, "part-6") :: - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(9, "part-9") :: - Row(10, "part-10") :: - Row(10, "part-10") :: Nil - ) + sql("SELECT * FROM normal_orc_source"), + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => + Seq.fill(2)(Row(i, s"part-$i")) + }) } test("overwrite insert") { - sql("insert overwrite table normal_orc_as_source select * " + - "from orc_temp_table where intField > 5") + sql( + """INSERT OVERWRITE TABLE normal_orc_as_source + |SELECT * FROM orc_temp_table WHERE intField > 5 + """.stripMargin) + checkAnswer( - sql("select * from normal_orc_as_source"), - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + sql("SELECT * FROM normal_orc_as_source"), + (6 to 10).map(i => Row(i, s"part-$i"))) } } @@ -192,21 +127,20 @@ class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() - sql( s""" - create temporary table normal_orc_source - USING org.apache.spark.sql.hive.orc - OPTIONS ( - path '${new File(orcTableDir.getAbsolutePath).getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table normal_orc_as_source - USING org.apache.spark.sql.hive.orc - OPTIONS ( - path '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' - ) - as select * from orc_temp_table - """) + sql( + s"""CREATE TEMPORARY TABLE normal_orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_as_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) } }
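
Usage note (reviewer-facing, not part of the patch): with the ORC source now registered under the short name "orc" and the string overload of DataFrameWriter.mode actually assigning the parsed SaveMode, round-tripping ORC data through the new reader/writer API looks like the sketch below. It assumes an existing HiveContext named `hiveContext` plus placeholder paths and columns; the ORC source still requires Hive support, as enforced by the new assertion in DefaultSource.createRelation.

    import org.apache.spark.sql.SaveMode

    // Load ORC files through the data source short name instead of the
    // removed hiveContext.orcFile(...) implicit.
    val people = hiveContext.read.format("orc").load("/tmp/people.orc")

    // Write the result back out as ORC; .mode("overwrite") is equivalent now
    // that the string overload assigns the parsed mode instead of dropping it.
    people.filter("age > 5")
      .write
      .format("orc")
      .mode(SaveMode.Overwrite)
      .save("/tmp/adults.orc")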
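
The unwrapperFor helper merged into HiveInspectors keeps the TableReader approach: the ObjectInspector pattern match runs once per struct field when the closures are built, and each row only applies the prebuilt setters (see OrcTableScan's `fieldRefs.map(unwrapperFor)`). Below is a simplified, self-contained Scala sketch of that idea; Field, setterFor, and UnwrapperSketch are illustrative names only, not part of this patch or of Hive's inspector API.

    object UnwrapperSketch extends App {
      sealed trait Field
      case object IntField extends Field
      case object StringField extends Field

      // Dispatch on the field type once, up front, returning a specialized setter.
      def setterFor(field: Field): (Any, Array[Any], Int) => Unit = field match {
        case IntField    => (value, row, ordinal) => row(ordinal) = value.asInstanceOf[Int]
        case StringField => (value, row, ordinal) => row(ordinal) = value.toString
      }

      val schema  = Seq[Field](IntField, StringField)
      val setters = schema.map(setterFor)

      // Per row: no pattern matching or branching, just apply the prebuilt setters.
      val raw = Seq[Any](42, "val_42")
      val row = new Array[Any](schema.length)
      setters.zipWithIndex.foreach { case (set, ordinal) => set(raw(ordinal), row, ordinal) }

      println(row.mkString(", "))  // 42, val_42
    }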