[SPARK-1368][SQL] Optimized HiveTableScan #758

Closed · wants to merge 5 commits
@@ -116,7 +116,7 @@ case class Aggregate(
*/
@transient
private[this] lazy val resultMap =
(computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap
(computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap

/**
* Substituted version of aggregateExpressions expressions which are used to compute final
@@ -18,15 +18,18 @@
package org.apache.spark.sql.hive.execution

import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.MetaStoreUtils
import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
import org.apache.hadoop.hive.serde2.Serializer
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred._

@@ -37,6 +40,7 @@ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive._
import org.apache.spark.{TaskContext, SparkException}
import org.apache.spark.util.MutablePair

/* Implicits */
import scala.collection.JavaConversions._
@@ -94,24 +98,63 @@ case class HiveTableScan(
(_: Any, partitionKeys: Array[String]) => {
val value = partitionKeys(ordinal)
val dataType = relation.partitionKeys(ordinal).dataType
castFromString(value, dataType)
unwrapHiveData(castFromString(value, dataType))
}
} else {
val ref = objectInspector.getAllStructFieldRefs
.find(_.getFieldName == a.name)
.getOrElse(sys.error(s"Can't find attribute $a"))
(row: Any, _: Array[String]) => {
val data = objectInspector.getStructFieldData(row, ref)
unwrapData(data, ref.getFieldObjectInspector)
unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
}
}
}
}

private def unwrapHiveData(value: Any) = value match {
case maybeNull: String if maybeNull.toLowerCase == "null" => null
Contributor comment: Would it be possible to calculate the required unwrapping statically based on the object inspector instead of doing this match for every data item?

case varchar: HiveVarchar => varchar.getValue
case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
case other => other
}
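
Editorial sketch of the statically chosen unwrapping the contributor comment above asks about (illustrative only: unwrapperFor is a hypothetical helper, not part of this PR). The conversion would be picked once per field from its ObjectInspector, removing the per-value pattern match from the scan loop; the string "null" check for partition keys would still need separate handling.

private def unwrapperFor(inspector: ObjectInspector): Any => Any = inspector match {
  // Only HiveVarchar and HiveDecimal need conversion; everything else passes through.
  case _: JavaHiveVarcharObjectInspector =>
    (value: Any) => if (value == null) null else value.asInstanceOf[HiveVarchar].getValue
  case _: JavaHiveDecimalObjectInspector =>
    (value: Any) =>
      if (value == null) null else BigDecimal(value.asInstanceOf[HiveDecimal].bigDecimalValue)
  case _ =>
    (value: Any) => value
}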

private def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
}

private def addColumnMetadataToConf(hiveConf: HiveConf) {
// Specifies IDs and internal names of columns to be scanned.
val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer)
val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",")

if (attributes.size == relation.output.size) {
ColumnProjectionUtils.setFullyReadColumns(hiveConf)
} else {
ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
}

ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name))

// Specifies types and object inspectors of columns to be scanned.
val structOI = ObjectInspectorUtils
.getStandardObjectInspector(
relation.tableDesc.getDeserializer.getObjectInspector,
ObjectInspectorCopyOption.JAVA)
.asInstanceOf[StructObjectInspector]

val columnTypeNames = structOI
.getAllStructFieldRefs
.map(_.getFieldObjectInspector)
.map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName)
.mkString(",")

hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames)
hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames)
}
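
As a rough worked example of the column metadata this leaves in the conf (editorial, not part of the PR; the table is hypothetical and the key names are recalled from ColumnProjectionUtils and serdeConstants, so worth double-checking): for SELECT value FROM src, where src is (key INT, value STRING), only column 1 is needed, so the settings come out approximately as:

  hive.io.file.readcolumn.ids   = "1"            (ColumnProjectionUtils.appendReadColumnIDs)
  hive.io.file.readcolumn.names = "value"        (ColumnProjectionUtils.appendReadColumnNames)
  columns                       = "_col1"        (serdeConstants.LIST_COLUMNS)
  columns.types                 = "int,string"   (serdeConstants.LIST_COLUMN_TYPES)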

addColumnMetadataToConf(sc.hiveconf)

@transient
def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
@@ -143,20 +186,42 @@ case class HiveTableScan(
}

def execute() = {
inputRdd.map { row =>
val values = row match {
case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
attributeFunctions.map(_(deserializedRow, partitionKeys))
case deserializedRow: AnyRef =>
attributeFunctions.map(_(deserializedRow, Array.empty))
inputRdd.mapPartitions { iterator =>
if (iterator.isEmpty) {
Iterator.empty
} else {
val mutableRow = new GenericMutableRow(attributes.length)
val mutablePair = new MutablePair[Any, Array[String]]()
val buffered = iterator.buffered

// NOTE (lian): This is the critical path of the Hive table scan; unnecessary FP-style code and
// pattern matching are avoided intentionally.
val rowsAndPartitionKeys = buffered.head match {
// With partition keys
case _: Array[Any] =>
buffered.map { case array: Array[Any] =>
val deserializedRow = array(0)
val partitionKeys = array(1).asInstanceOf[Array[String]]
mutablePair.update(deserializedRow, partitionKeys)
}

// Without partition keys
case _ =>
val emptyPartitionKeys = Array.empty[String]
buffered.map { deserializedRow =>
mutablePair.update(deserializedRow, emptyPartitionKeys)
}
}

rowsAndPartitionKeys.map { pair =>
var i = 0
while (i < attributes.length) {
mutableRow(i) = attributeFunctions(i)(pair._1, pair._2)
i += 1
}
mutableRow: Row
}
}
buildRow(values.map {
case n: String if n.toLowerCase == "null" => null
case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
BigDecimal(decimal.bigDecimalValue)
case other => other
})
}
}
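
The mapPartitions rewrite above allocates one GenericMutableRow and one MutablePair per partition and mutates them for every row, instead of building fresh objects per row. A minimal standalone sketch of the same reuse pattern (editorial, not part of the PR; downstream code must copy the result if it needs to keep it across iterations):

import org.apache.spark.util.MutablePair

// One pair per partition, updated in place for every element. map() yields the
// same object each time, so consumers must not buffer references to it.
def reusedPairs(iterator: Iterator[Int]): Iterator[MutablePair[Int, String]] = {
  val pair = new MutablePair[Int, String]()
  iterator.map(i => pair.update(i, "row-" + i))
}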

@@ -19,11 +19,12 @@ package org.apache.spark.sql.hive.execution

import java.io._

import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}

import org.apache.spark.sql.Logging
import org.apache.spark.sql.catalyst.plans.logical.{ExplainCommand, NativeCommand}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.Sort
import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
import org.apache.spark.sql.hive.test.TestHive

/**
@@ -128,17 +129,19 @@ abstract class HiveComparisonTest
protected def prepareAnswer(
hiveQuery: TestHive.type#HiveQLQueryExecution,
answer: Seq[String]): Seq[String] = {

def isSorted(plan: LogicalPlan): Boolean = plan match {
case _: Join | _: Aggregate | _: BaseRelation | _: Generate | _: Sample | _: Distinct => false
case PhysicalOperation(_, _, Sort(_, _)) => true
case _ => plan.children.iterator.map(isSorted).exists(_ == true)
}

val orderedAnswer = hiveQuery.logical match {
// Clean out non-deterministic time schema info.
case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
case _: ExplainCommand => answer
case _ =>
// TODO: Really we only care about the final total ordering here...
val isOrdered = hiveQuery.executedPlan.collect {
case s @ Sort(_, global, _) if global => s
}.nonEmpty
// If the query results aren't sorted, then sort them to ensure deterministic answers.
if (!isOrdered) answer.sorted else answer
case plan if isSorted(plan) => answer
case _ => answer.sorted
}
orderedAnswer.map(cleanPaths)
}
@@ -161,7 +164,7 @@ abstract class HiveComparisonTest
"minFileSize"
)
protected def nonDeterministicLine(line: String) =
nonDeterministicLineIndicators.map(line contains _).reduceLeft(_||_)
nonDeterministicLineIndicators.exists(line contains _)

/**
* Removes non-deterministic paths from `str` so cached answers will compare correctly.