Addresses comments
- Migrates to the new DataFrame reader/writer API
- Merges HadoopTypeConverter into HiveInspectors
- Refactors test suites
liancheng committed May 16, 2015
1 parent 21ada22 commit d4afeed
Showing 11 changed files with 249 additions and 449 deletions.
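The first commit bullet, migrating to the new DataFrame reader/writer API, drives most of the test-suite changes further down. A hedged before/after sketch, assuming a Spark 1.4-era Hive-enabled context; the `saveAsOrcFile`/`orcFile` shorthands are the pre-migration helpers this commit removes from the tests:

```scala
// Illustrative only, not taken verbatim from this commit.
// Old, ORC-specific helpers (removed):
//   df.saveAsOrcFile("/tmp/people.orc")
//   val people = hiveContext.orcFile("/tmp/people.orc")

// New, generic reader/writer API used throughout the updated suites:
df.write.format("orc").mode("overwrite").save("/tmp/people.orc")
val people = hiveContext.read.format("orc").load("/tmp/people.orc")
```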
@@ -55,7 +55,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def mode(saveMode: String): DataFrameWriter = {
saveMode.toLowerCase match {
this.mode = saveMode.toLowerCase match {
case "overwrite" => SaveMode.Overwrite
case "append" => SaveMode.Append
case "ignore" => SaveMode.Ignore
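The one-line change above fixes the string overload of `mode`: previously the `SaveMode` produced by the match was discarded, so `this.mode` kept its default; now it is assigned. A hedged usage sketch of why the assignment matters, assuming `df` is a DataFrame and `SaveMode` is imported from `org.apache.spark.sql`:

```scala
// Illustrative only. With the fix, the string form behaves like the enum form:
df.write
  .format("parquet")
  .mode("overwrite")            // now actually records SaveMode.Overwrite
  .save("/tmp/out")

// equivalent to:
df.write.format("parquet").mode(SaveMode.Overwrite).save("/tmp/out")
```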
18 changes: 12 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -188,18 +188,20 @@ private[sql] class DDLParser(
private[sql] object ResolvedDataSource {

private val builtinSources = Map(
"jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource],
"json" -> classOf[org.apache.spark.sql.json.DefaultSource],
"parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource]
"jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource",
"json" -> "org.apache.spark.sql.json.DefaultSource",
"parquet" -> "org.apache.spark.sql.parquet.DefaultSource",
"orc" -> "org.apache.spark.sql.hive.orc.DefaultSource"
)

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String): Class[_] = {
val loader = Utils.getContextOrSparkClassLoader

if (builtinSources.contains(provider)) {
return builtinSources(provider)
return loader.loadClass(builtinSources(provider))
}

val loader = Utils.getContextOrSparkClassLoader
try {
loader.loadClass(provider)
} catch {
@@ -208,7 +210,11 @@ private[sql] object ResolvedDataSource {
loader.loadClass(provider + ".DefaultSource")
} catch {
case cnf: java.lang.ClassNotFoundException =>
sys.error(s"Failed to load class for data source: $provider")
if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
sys.error("The ORC data source must be used with Hive support enabled.")
} else {
sys.error(s"Failed to load class for data source: $provider")
}
}
}
}
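Two things change in `ResolvedDataSource` above: the builtin map now stores fully qualified class names rather than `Class` objects, so `sql/core` can register the ORC source without a compile-time dependency on the hive module, and lookup failures under the ORC package get a clearer error message. A simplified, self-contained sketch of the resulting resolution order (illustrative; the real method also hoists the class loader and wraps the errors differently):

```scala
// Simplified sketch of the lookup strategy: builtin alias -> provider class ->
// provider + ".DefaultSource", with a dedicated message for the Hive-only ORC source.
def lookupDataSourceSketch(provider: String,
                           builtins: Map[String, String],
                           loader: ClassLoader): Class[_] = {
  builtins.get(provider) match {
    case Some(className) => loader.loadClass(className)
    case None =>
      try loader.loadClass(provider)
      catch {
        case _: ClassNotFoundException =>
          try loader.loadClass(provider + ".DefaultSource")
          catch {
            case _: ClassNotFoundException if provider.startsWith("org.apache.spark.sql.hive.orc") =>
              sys.error("The ORC data source must be used with Hive support enabled.")
            case _: ClassNotFoundException =>
              sys.error(s"Failed to load class for data source: $provider")
          }
      }
  }
}
```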
@@ -18,8 +18,8 @@
package org.apache.spark.sql.hive

import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}

@@ -122,7 +122,7 @@ import scala.collection.JavaConversions._
* even a normal java object (POJO)
* UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet)
*
* 3) ConstantObjectInspector:
* 3) ConstantObjectInspector:
* Constant object inspector can be either primitive type or Complex type, and it bundles a
* constant value as its property, usually the value is created when the constant object inspector
* constructed.
@@ -133,7 +133,7 @@ import scala.collection.JavaConversions._
}
}}}
* Hive provides 3 built-in constant object inspectors:
* Primitive Object Inspectors:
* Primitive Object Inspectors:
* WritableConstantStringObjectInspector
* WritableConstantHiveVarcharObjectInspector
* WritableConstantHiveDecimalObjectInspector
@@ -147,9 +147,9 @@ import scala.collection.JavaConversions._
* WritableConstantByteObjectInspector
* WritableConstantBinaryObjectInspector
* WritableConstantDateObjectInspector
* Map Object Inspector:
* Map Object Inspector:
* StandardConstantMapObjectInspector
* List Object Inspector:
* List Object Inspector:
* StandardConstantListObjectInspector]]
* Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct
* Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union
@@ -250,9 +250,9 @@ private[hive] trait HiveInspectors {
poi.getWritableConstantValue.getHiveDecimal)
case poi: WritableConstantTimestampObjectInspector =>
poi.getWritableConstantValue.getTimestamp.clone()
case poi: WritableConstantIntObjectInspector =>
case poi: WritableConstantIntObjectInspector =>
poi.getWritableConstantValue.get()
case poi: WritableConstantDoubleObjectInspector =>
case poi: WritableConstantDoubleObjectInspector =>
poi.getWritableConstantValue.get()
case poi: WritableConstantBooleanObjectInspector =>
poi.getWritableConstantValue.get()
@@ -306,7 +306,7 @@ private[hive] trait HiveInspectors {
// In order to keep backward-compatible, we have to copy the
// bytes with old apis
val bw = x.getPrimitiveWritableObject(data)
val result = new Array[Byte](bw.getLength())
val result = new Array[Byte](bw.getLength())
System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
result
case x: DateObjectInspector if x.preferWritable() =>
@@ -394,6 +394,30 @@ private[hive] trait HiveInspectors {
identity[Any]
}

/**
* Builds specific unwrappers ahead of time according to object inspector
* types to avoid pattern matching and branching costs per row.
*/
def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit =
field.getFieldObjectInspector match {
case oi: BooleanObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value))
case oi: ByteObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value))
case oi: ShortObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value))
case oi: IntObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value))
case oi: LongObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value))
case oi: FloatObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value))
case oi: DoubleObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
case oi =>
(value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi)
}

/**
* Converts native catalyst types to the types expected by Hive
* @param a the value to be wrapped
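`unwrapperFor` is the part of the deleted `HadoopTypeConverter` that moves into `HiveInspectors` (per the commit message). A hedged sketch of how these precomputed per-field unwrappers are typically driven per row, mirroring the `OrcTableScan` change later in this diff; `soi`, `deserializer`, `fieldRefs`, `ordinals`, and `mutableRow` are assumed to be set up as in that code:

```scala
// Sketch only: build one unwrapper closure per struct field up front, then run a
// tight loop per row, so ObjectInspector dispatch is paid once per field rather
// than once per field per row.
val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map(unwrapperFor)

iterator.map { writable =>
  val raw = deserializer.deserialize(writable)
  var i = 0
  while (i < fieldRefs.length) {
    val fieldValue = soi.getStructFieldData(raw, fieldRefs(i))
    if (fieldValue == null) {
      mutableRow.setNullAt(ordinals(i))
    } else {
      unwrappers(i)(fieldValue, mutableRow, ordinals(i))
    }
    i += 1
  }
  mutableRow: Row
}
```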

This file was deleted.

@@ -34,7 +34,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.{HadoopRDD, RDD}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveShim}
import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim}
import org.apache.spark.sql.sources.{Filter, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
@@ -50,6 +50,10 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider {
schema: Option[StructType],
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation = {
assert(
sqlContext.isInstanceOf[HiveContext],
"The ORC data source can only be used with HiveContext.")

val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition]))
OrcRelation(paths, parameters, schema, partitionSpec)(sqlContext)
}
@@ -59,7 +63,7 @@ private[orc] class OrcOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext)
extends OutputWriter with SparkHadoopMapRedUtil {
extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors {

private val serializer = {
val table = new Properties()
@@ -89,7 +93,7 @@ private[orc] class OrcOutputWriter(

// Used to convert Catalyst values into Hadoop `Writable`s.
private val wrappers = structOI.getAllStructFieldRefs.map { ref =>
HadoopTypeConverter.wrappers(ref.getFieldObjectInspector)
wrapperFor(ref.getFieldObjectInspector)
}.toArray

// `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this
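`wrapperFor` is the counterpart of `unwrapperFor`: it is looked up once per field and converts Catalyst values into the Hadoop `Writable`s that Hive's serializer expects. A hedged sketch of how a writer can apply the precomputed wrappers when serializing a row, assuming `wrapperFor` yields an `Any => Any` conversion as in `HiveInspectors`; the reusable buffer is illustrative:

```scala
// Sketch only: wrappers(i) was built once via wrapperFor(ref.getFieldObjectInspector),
// so the per-row work is an indexed loop with no ObjectInspector pattern matching.
def wrapRow(row: Row, wrappers: Array[Any => Any], reused: Array[AnyRef]): Array[AnyRef] = {
  var i = 0
  while (i < row.length) {
    reused(i) = wrappers(i)(row(i)).asInstanceOf[AnyRef]
    i += 1
  }
  reused
}
```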
@@ -190,7 +194,10 @@ private[orc] case class OrcTableScan(
attributes: Seq[Attribute],
@transient relation: OrcRelation,
filters: Array[Filter],
inputPaths: Array[String]) extends Logging {
inputPaths: Array[String])
extends Logging
with HiveInspectors {

@transient private val sqlContext = relation.sqlContext

private def addColumnIds(
@@ -215,7 +222,7 @@ private[orc] case class OrcTableScan(
case (attr, ordinal) =>
soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal
}.unzip
val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs)
val unwrappers = fieldRefs.map(unwrapperFor)
// Map each tuple to a row object
iterator.map { value =>
val raw = deserializer.deserialize(value)
@@ -240,7 +247,7 @@ private[orc] case class OrcTableScan(
// Tries to push down filters if ORC filter push-down is enabled
if (sqlContext.conf.orcFilterPushDown) {
OrcFilters.createFilter(filters).foreach { f =>
conf.set(SARG_PUSHDOWN, f.toKryo)
conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo)
conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true)
}
}
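Filter push-down is only attempted when the ORC push-down flag is enabled in the SQL conf; eligible Catalyst `Filter`s are converted to a Hive `SearchArgument`, Kryo-serialized, and passed to the ORC reader under the non-public `sarg.pushdown` key. A hedged usage sketch, assuming the 1.4-era conf key `spark.sql.orc.filterPushdown`:

```scala
// Illustrative only.
hiveContext.setConf("spark.sql.orc.filterPushdown", "true")

val df = hiveContext.read.format("orc").load("/path/to/orc")
// Predicates that OrcFilters can translate may now be evaluated inside the ORC
// reader (row-group skipping) instead of only after the rows are materialized.
df.filter(df("a") > 1 && df("b").isNotNull).count()
```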
@@ -276,3 +283,8 @@ private[orc] case class OrcTableScan(
}
}
}

private[orc] object OrcTableScan {
// This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public.
private[orc] val SARG_PUSHDOWN = "sarg.pushdown"
}

This file was deleted.

@@ -41,7 +41,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
sparkContext.parallelize(data).toDF().saveAsOrcFile(file.getCanonicalPath)
sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@@ -53,8 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
import org.apache.spark.sql.hive.orc.OrcContext
withOrcFile(data)(path => f(hiveContext.orcFile(path)))
withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path)))
}

/**
@@ -73,12 +72,12 @@

protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite)
data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath)
}

protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
df: DataFrame, path: File): Unit = {
df.save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite)
df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath)
}
}

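With the helpers above rewritten against the generic reader/writer API, a typical test built on `OrcTest` looks roughly like the following. This is a hedged sketch: the test name and data are illustrative, and the suite is assumed to mix in something providing `checkAnswer`, as Spark's `QueryTest` does:

```scala
// Sketch of how withOrcDataFrame is meant to be used by the ORC suites.
test("round-trips simple tuples through ORC") {
  val data = (1 to 10).map(i => (i, s"val_$i"))
  withOrcDataFrame(data) { df =>
    // df was written with df.write.format("orc") and read back via read.format("orc").
    checkAnswer(df.filter(df("_1") > 5), data.filter(_._1 > 5).map(Row.fromTuple))
  }
}
```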
@@ -40,7 +40,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
sparkContext
.parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
.toDF("a", "b", "p1")
.saveAsOrcFile(partitionDir.toString)
.write
.format("orc")
.save(partitionDir.toString)
}

val dataSchemaWithPartition =