diff --git a/build.sbt b/build.sbt index bc74bd56ff186..39e0d9805b088 100644 --- a/build.sbt +++ b/build.sbt @@ -12,7 +12,10 @@ scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked") resolvers += "Local Maven Repository" at "file://"+Path.userHome.absolutePath+"/.m2/repository" -libraryDependencies += "org.apache.spark" %% "spark-core" % "0.9.0-incubating-SNAPSHOT" +// TODO: Remove when Spark 0.9.0 is released for real. +resolvers += "SparkStaging" at "https://repository.apache.org/content/repositories/orgapachespark-1006/" + +libraryDependencies += "org.apache.spark" %% "spark-core" % "0.9.0-incubating" libraryDependencies += "catalyst" % "hive-golden" % "4" % "test" from "http://repository-databricks.forge.cloudbees.com/snapshot/catalystGolden4.jar" diff --git a/src/main/scala/catalyst/analysis/Analyzer.scala b/src/main/scala/catalyst/analysis/Analyzer.scala index 58f69650057f2..0ba054d7b3041 100644 --- a/src/main/scala/catalyst/analysis/Analyzer.scala +++ b/src/main/scala/catalyst/analysis/Analyzer.scala @@ -34,6 +34,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool StarExpansion :: ResolveFunctions :: GlobalAggregates :: + PreInsertionCasts :: typeCoercionRules :_*) ) @@ -106,7 +107,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case agg: AggregateExpression => return true case _ => }) - return false + false } } @@ -141,4 +142,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool protected def containsStar(exprs: Seq[NamedExpression]): Boolean = exprs.collect { case _: Star => true }.nonEmpty } + + /** + * Casts input data to correct data types according to table definition before inserting into + * that table. + */ + object PreInsertionCasts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transform { + // Wait until children are resolved + case p: LogicalPlan if !p.childrenResolved => p + + case p @ InsertIntoTable(table, _, child) => + val childOutputDataTypes = child.output.map(_.dataType) + val tableOutputDataTypes = table.output.map(_.dataType) + + if (childOutputDataTypes sameElements tableOutputDataTypes) { + p + } else { + // Only do the casting when child output data types differ from table output data types. 
+ val castedChildOutput = child.output.zip(table.output).map { + case (l, r) if l.dataType != r.dataType => Alias(Cast(l, r.dataType), l.name)() + case (l, _) => l + } + + p.copy(child = Project(castedChildOutput, child)) + } + } + } } diff --git a/src/main/scala/catalyst/execution/SharkInstance.scala b/src/main/scala/catalyst/execution/SharkInstance.scala index 256f07efe02ad..77599ba8e059f 100644 --- a/src/main/scala/catalyst/execution/SharkInstance.scala +++ b/src/main/scala/catalyst/execution/SharkInstance.scala @@ -73,6 +73,7 @@ abstract class SharkInstance extends Logging { val sc = self.sc val strategies = SparkEquiInnerJoin :: + PartitionPrunings :: HiveTableScans :: DataSinks :: BasicOperators :: @@ -81,7 +82,8 @@ abstract class SharkInstance extends Logging { } object PrepareForExecution extends RuleExecutor[SharkPlan] { - val batches = Batch("Prepare Expressions", Once, expressions.BindReferences) :: Nil + val batches = + Batch("Prepare Expressions", Once, new expressions.BindReferences[SharkPlan]) :: Nil } class SharkSqlQuery(sql: String) extends SharkQuery { diff --git a/src/main/scala/catalyst/execution/TableReader.scala b/src/main/scala/catalyst/execution/TableReader.scala index 0f48788812301..040d57f95fe9e 100644 --- a/src/main/scala/catalyst/execution/TableReader.scala +++ b/src/main/scala/catalyst/execution/TableReader.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.{HadoopRDD, UnionRDD, EmptyRDD, RDD} * type of table storage: HeapTableReader for Shark tables in Spark's block manager, * TachyonTableReader for tables in Tachyon, and HadoopTableReader for Hive tables in a filesystem. */ -sealed trait TableReader { +private[catalyst] sealed trait TableReader { def makeRDDForTable(hiveTable: HiveTable): RDD[_] @@ -34,7 +34,7 @@ sealed trait TableReader { * Helper class for scanning tables stored in Hadoop - e.g., to read Hive tables that reside in the * data warehouse directory. */ -class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf: HiveConf) +private[catalyst] class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf: HiveConf) extends TableReader { // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless @@ -93,11 +93,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf deserializer.initialize(hconf, tableDesc.getProperties) // Deserialize each Writable to get the row value. 
- iter.map { value => - value match { - case v: Writable => deserializer.deserialize(v) - case _ => throw new RuntimeException("Failed to match " + value.toString) - } + iter.map { + case v: Writable => deserializer.deserialize(v) + case value => + sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}") } } deserializedHadoopRDD @@ -130,8 +129,8 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf val ifc = partDesc.getInputFileFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] // Get partition field info - val partSpec = partDesc.getPartSpec() - val partProps = partDesc.getProperties() + val partSpec = partDesc.getPartSpec + val partProps = partDesc.getProperties val partColsDelimited: String = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) // Partitioning columns are delimited by "/" @@ -156,7 +155,7 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf iter.map { value => val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - val deserializedRow = deserializer.deserialize(value) // LazyStruct + val deserializedRow = deserializer.deserialize(value) rowWithPartArr.update(0, deserializedRow) rowWithPartArr.update(1, partValues) rowWithPartArr.asInstanceOf[Object] @@ -177,11 +176,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf */ private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { filterOpt match { - case Some(filter) => { + case Some(filter) => val fs = path.getFileSystem(_localHConf) val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) filteredFiles.mkString(",") - } case None => path.toString } } @@ -212,7 +210,7 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient _localHConf } -object HadoopTableReader { +private[catalyst] object HadoopTableReader { /** * Curried. 
After given an argument for 'path', the resulting JobConf => Unit closure is used to diff --git a/src/main/scala/catalyst/execution/TestShark.scala b/src/main/scala/catalyst/execution/TestShark.scala index f297c14ca805e..7744c8f0fbc29 100644 --- a/src/main/scala/catalyst/execution/TestShark.scala +++ b/src/main/scala/catalyst/execution/TestShark.scala @@ -11,6 +11,10 @@ import scala.language.implicitConversions import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor} import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.hadoop.hive.ql.exec.FunctionRegistry +import org.apache.hadoop.hive.ql.io.avro.{AvroContainerOutputFormat, AvroContainerInputFormat} +import org.apache.hadoop.hive.serde2.avro.AvroSerDe +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.RegexSerDe import analysis._ import plans.logical.LogicalPlan @@ -164,11 +168,20 @@ object TestShark extends SharkInstance { "CREATE TABLE IF NOT EXISTS dest3 (key INT, value STRING)".cmd), TestTable("srcpart", () => { runSqlHive("CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") - Seq("2008-04-08", "2008-04-09").foreach { ds => - Seq("11", "12").foreach { hr => - val partSpec = Map("ds" -> ds, "hr" -> hr) - runSqlHive(s"LOAD DATA LOCAL INPATH '${hiveDevHome.getCanonicalPath}/data/files/kv1.txt' OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')") - } + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${hiveDevHome.getCanonicalPath}/data/files/kv1.txt' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("srcpart1", () => { + runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${hiveDevHome.getCanonicalPath}/data/files/kv1.txt' + |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') + """.stripMargin) } }), TestTable("src_thrift", () => { @@ -194,7 +207,51 @@ object TestShark extends SharkInstance { catalog.client.createTable(srcThrift) runSqlHive(s"LOAD DATA LOCAL INPATH '${hiveDevHome.getCanonicalPath}/data/files/complex.seq' INTO TABLE src_thrift") - }) + }), + TestTable("serdeins", + s"""CREATE TABLE serdeins (key INT, value STRING) + |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ('field.delim'='\\t') + """.stripMargin.cmd, + "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), + TestTable("sales", + s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${hiveDevHome.getCanonicalPath}/data/files/sales.txt' INTO TABLE sales".cmd), + TestTable("episodes", + s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) + |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' + |STORED AS + |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' + |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", 
+ | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${hiveDevHome.getCanonicalPath}/data/files/episodes.avro' INTO TABLE episodes".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) @@ -203,11 +260,12 @@ object TestShark extends SharkInstance { def loadTestTable(name: String) { if (!(loadedTables contains name)) { + // Marks the table as loaded first to prevent infite mutually recursive table loading. + loadedTables += name logger.info(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) - loadedTables += name } } diff --git a/src/main/scala/catalyst/execution/hiveOperators.scala b/src/main/scala/catalyst/execution/hiveOperators.scala index efdaa94950ac7..68fe99b5dc9cd 100644 --- a/src/main/scala/catalyst/execution/hiveOperators.scala +++ b/src/main/scala/catalyst/execution/hiveOperators.scala @@ -1,21 +1,54 @@ package catalyst package execution +import java.nio.file.Files + import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.`type`.HiveVarchar import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils -import org.apache.hadoop.hive.ql.plan.FileSinkDesc +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} +import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc} import org.apache.hadoop.hive.serde2.AbstractSerDe -import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, StructObjectInspector} -import org.apache.hadoop.hive.serde2.`lazy`.LazyStruct +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector +import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapred.JobConf +import org.apache.spark.SparkContext._ -import expressions.Attribute -import util._ +import catalyst.expressions._ +import catalyst.types.{BooleanType, DataType} /* Implicits */ import scala.collection.JavaConversions._ -case class HiveTableScan(attributes: Seq[Attribute], relation: MetastoreRelation) extends LeafNode { +/** + * The Hive table scan operator. Column and partition pruning are both handled. + * + * @constructor + * @param attributes Attributes to be fetched from the Hive table. + * @param relation The Hive table be be scanned. + * @param partitionPruningPred An optional partition pruning predicate for partitioned table. + */ +case class HiveTableScan( + attributes: Seq[Attribute], + relation: MetastoreRelation, + partitionPruningPred: Option[Expression]) + extends LeafNode { + + require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, + "Partition pruning predicates only supported for partitioned tables.") + + // Bind all partition key attribute references in the partition pruning predicate for later + // evaluation. 
+ private val boundPruningPred = partitionPruningPred.map { pred => + require( + pred.dataType == BooleanType, + s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + + BindReferences.bindReference(pred, Seq(relation.partitionKeys)) + } + @transient val hadoopReader = new HadoopTableReader(relation.tableDesc, SharkContext.hiveconf) @@ -28,21 +61,25 @@ case class HiveTableScan(attributes: Seq[Attribute], relation: MetastoreRelation relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] /** - * Functions that extract the requested attributes from the hive output. + * Functions that extract the requested attributes from the hive output. Partitioned values are + * casted from string to its declared data type. */ @transient - protected lazy val attributeFunctions: Seq[(LazyStruct, Array[String]) => AnyRef] = { + protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { attributes.map { a => - if (relation.partitionKeys.contains(a)) { - val ordinal = relation.partitionKeys.indexOf(a) - (struct: LazyStruct, partitionKeys: Array[String]) => partitionKeys(ordinal) + val ordinal = relation.partitionKeys.indexOf(a) + if (ordinal >= 0) { + (_: Any, partitionKeys: Array[String]) => { + val value = partitionKeys(ordinal) + val dataType = relation.partitionKeys(ordinal).dataType + castFromString(value, dataType) + } } else { val ref = objectInspector.getAllStructFieldRefs .find(_.getFieldName == a.name) .getOrElse(sys.error(s"Can't find attribute $a")) - - (struct: LazyStruct, _: Array[String]) => { - val data = objectInspector.getStructFieldData(struct, ref) + (row: Any, _: Array[String]) => { + val data = objectInspector.getStructFieldData(row, ref) val inspector = ref.getFieldObjectInspector.asInstanceOf[PrimitiveObjectInspector] inspector.getPrimitiveJavaObject(data) } @@ -50,24 +87,50 @@ case class HiveTableScan(attributes: Seq[Attribute], relation: MetastoreRelation } } + private def castFromString(value: String, dataType: DataType) = { + Evaluate(Cast(Literal(value), dataType), Nil) + } + @transient def inputRdd = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { - hadoopReader.makeRDDForPartitionedTable(relation.hiveQlPartitions) + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) + } + + /** + * Prunes partitions not involve the query plan. + * + * @param partitions All partitions of the relation. + * @return Partitions that are involved in the query plan. + */ + private[catalyst] def prunePartitions(partitions: Seq[HivePartition]) = { + boundPruningPred match { + case None => partitions + case Some(shouldKeep) => partitions.filter { part => + val dataTypes = relation.partitionKeys.map(_.dataType) + val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { + castFromString(value, dataType) + } + + // Only partitioned values are needed here, since the predicate has already been bound to + // partition key attribute references. 
+ val row = new GenericRow(castedValues) + Evaluate(shouldKeep, Seq(row)).asInstanceOf[Boolean] + } + } } def execute() = { inputRdd.map { row => val values = row match { - case Array(struct: LazyStruct, partitionKeys: Array[String]) => - attributeFunctions.map(_(struct, partitionKeys)) - case struct: LazyStruct => - attributeFunctions.map(_(struct, Array.empty)) + case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) => + attributeFunctions.map(_(deserializedRow, partitionKeys)) + case deserializedRow: AnyRef => + attributeFunctions.map(_(deserializedRow, Array.empty)) } buildRow(values.map { - case "NULL" => null - case "null" => null + case n: String if n.toLowerCase == "null" => null case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue case other => other }) @@ -88,12 +151,7 @@ case class InsertIntoHiveTable( */ val desc = new FileSinkDesc("./", table.tableDesc, false) - val outputClass = { - val serializer = - table.tableDesc.getDeserializerClass.newInstance().asInstanceOf[AbstractSerDe] - serializer.initialize(null, table.tableDesc.getProperties) - serializer.getSerializedClass - } + val outputClass = newSerializer(table.tableDesc).getSerializedClass lazy val conf = new JobConf() @@ -105,30 +163,66 @@ case class InsertIntoHiveTable( new Path((new org.apache.hadoop.fs.RawLocalFileSystem).getWorkingDirectory, "test.out"), null) + private def newSerializer(tableDesc: TableDesc) = { + val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[AbstractSerDe] + serializer.initialize(null, tableDesc.getProperties) + serializer + } + override def otherCopyArgs = sc :: Nil def output = child.output + /** + * Inserts all the rows in the table into Hive. Row objects are properly serialized with the + * `org.apache.hadoop.hive.serde2.SerDe` and the + * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. + */ def execute() = { require(partition.isEmpty, "Inserting into partitioned table not supported.") val childRdd = child.execute() assert(childRdd != null) - // TODO: write directly to hive - val tempDir = java.io.File.createTempFile("data", "tsv") - tempDir.delete() - tempDir.mkdir() - childRdd.map(_.map(a => stringOrNull(a.asInstanceOf[AnyRef])).mkString("\001")) - .saveAsTextFile(tempDir.getCanonicalPath) + // TODO write directly to Hive + val tempDir = Files.createTempDirectory("data").toFile + + // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer + // instances within the closure, since AbstractSerDe is not serializable while TableDesc is. + val tableDesc = table.tableDesc + childRdd.mapPartitions { iter => + val serializer = newSerializer(tableDesc) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector(serializer.getObjectInspector, ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + iter.map { row => + // Casts Strings to HiveVarchars when necessary. 
+ val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector) + val mappedRow = row.zip(fieldOIs).map { + case (s: String, oi: JavaHiveVarcharObjectInspector) => new HiveVarchar(s, s.size) + case (obj, _) => obj + } - val partitionSpec = - if (partition.nonEmpty) { - s"PARTITION (${partition.map { case (k,v) => s"$k=${v.get}" }.mkString(",")})" + (null, serializer.serialize(Array(mappedRow: _*), standardOI)) + } + }.saveAsHadoopFile( + tempDir.getCanonicalPath, + classOf[NullWritable], + outputClass, + tableDesc.getOutputFileFormatClass) + + val partitionSpec = if (partition.nonEmpty) { + partition.map { + case (k, Some(v)) => s"$k=$v" + // Dynamic partition inserts + case (k, None) => s"$k" + }.mkString(" PARTITION (", ", ", ")") } else { "" } - sc.runHive(s"LOAD DATA LOCAL INPATH '${tempDir.getCanonicalPath}/*' INTO TABLE ${table.tableName} $partitionSpec") + val inpath = tempDir.getCanonicalPath + "/*" + sc.runHive(s"LOAD DATA LOCAL INPATH '$inpath' INTO TABLE ${table.tableName}$partitionSpec") // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which @@ -136,4 +230,4 @@ case class InsertIntoHiveTable( // TODO: implement hive compatibility as rules. sc.makeRDD(Nil, 1) } -} \ No newline at end of file +} diff --git a/src/main/scala/catalyst/execution/planningStrategies.scala b/src/main/scala/catalyst/execution/planningStrategies.scala index c631f3ebba512..2a26269daea1c 100644 --- a/src/main/scala/catalyst/execution/planningStrategies.scala +++ b/src/main/scala/catalyst/execution/planningStrategies.scala @@ -23,9 +23,9 @@ trait PlanningStrategies { def apply(plan: LogicalPlan): Seq[SharkPlan] = plan match { // Push attributes into table scan when possible. case p @ logical.Project(projectList, m: MetastoreRelation) if isSimpleProject(projectList) => - execution.HiveTableScan(projectList.asInstanceOf[Seq[Attribute]], m) :: Nil + execution.HiveTableScan(projectList.asInstanceOf[Seq[Attribute]], m, None) :: Nil case m: MetastoreRelation => - execution.HiveTableScan(m.output, m) :: Nil + execution.HiveTableScan(m.output, m, None) :: Nil case _ => Nil } @@ -34,10 +34,40 @@ trait PlanningStrategies { * complex expressions. */ def isSimpleProject(projectList: Seq[NamedExpression]) = { - projectList.map { - case a: Attribute => true - case _ => false - }.reduceLeft(_ && _) + projectList.forall(_.isInstanceOf[Attribute]) + } + } + + /** + * A strategy used to detect filtering predicates on top of a partitioned relation to help + * partition pruning. + * + * This strategy itself doesn't perform partition pruning, it just collects and combines all the + * partition pruning predicates and pass them down to the underlying [[HiveTableScan]] operator, + * which does the actual pruning work. 
+ */ + object PartitionPrunings extends Strategy { + def apply(plan: LogicalPlan): Seq[SharkPlan] = plan match { + case p @ FilteredOperation(predicates, relation: MetastoreRelation) + if relation.hiveQlTable.isPartitioned => + + val partitionKeyIds = relation.partitionKeys.map(_.id).toSet + + // Filter out all predicates that only deal with partition keys + val (pruningPredicates, otherPredicates) = predicates.partition { + _.references.map(_.id).subsetOf(partitionKeyIds) + } + + val scan = execution.HiveTableScan( + relation.output, relation, pruningPredicates.reduceLeftOption(And)) + + otherPredicates + .reduceLeftOption(And) + .map(execution.Filter(_, scan)) + .getOrElse(scan) :: Nil + + case _ => + Nil } } diff --git a/src/main/scala/catalyst/expressions/BoundAttribute.scala b/src/main/scala/catalyst/expressions/BoundAttribute.scala index edb17a9b5a6e4..398d34fcc4449 100644 --- a/src/main/scala/catalyst/expressions/BoundAttribute.scala +++ b/src/main/scala/catalyst/expressions/BoundAttribute.scala @@ -4,7 +4,7 @@ package expressions import rules._ import errors._ -import execution.SharkPlan +import catalyst.plans.QueryPlan /** * A bound reference points to a specific slot in the input tuple, allowing the actual value to be retrieved more @@ -27,27 +27,37 @@ case class BoundReference(inputTuple: Int, ordinal: Int, baseReference: Attribut override def toString = s"$baseReference:$inputTuple.$ordinal" } -// TODO: Should run against any query plan, not just SharkPlans -object BindReferences extends Rule[SharkPlan] { - def apply(plan: SharkPlan): SharkPlan = { +class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { + import BindReferences._ + + def apply(plan: TreeNode): TreeNode = { plan.transform { - case leafNode: SharkPlan if leafNode.children.isEmpty => leafNode - case nonLeaf: SharkPlan => attachTree(nonLeaf, "Binding references in operator") { - logger.debug(s"Binding references in node ${nonLeaf.simpleString}") - nonLeaf.transformExpressions { - case a: AttributeReference => attachTree(a, "Binding attribute") { - val inputTuple = nonLeaf.children.indexWhere(_.output contains a) - val ordinal = if (inputTuple == -1) -1 else nonLeaf.children(inputTuple).output.indexWhere(_ == a) - if (ordinal == -1) { - logger.debug(s"No binding found for $a given input ${nonLeaf.children.map(_.output.mkString("{", ",", "}")).mkString(",")}") - a - } else { - logger.debug(s"Binding $a to $inputTuple.$ordinal given input ${nonLeaf.children.map(_.output.mkString("{", ",", "}")).mkString(",")}") - BoundReference(inputTuple, ordinal, a) - } - } + case leafNode if leafNode.children.isEmpty => leafNode + case nonLeaf => nonLeaf.transformExpressions { case e => + bindReference(e, nonLeaf.children.map(_.output)) + } + } + } +} + +object BindReferences extends Logging { + def bindReference(expression: Expression, input: Seq[Seq[Attribute]]): Expression = { + expression.transform { case a: AttributeReference => + attachTree(a, "Binding attribute") { + def inputAsString = input.map(_.mkString("{", ",", "}")).mkString(",") + + for { + (tuple, inputTuple) <- input.zipWithIndex + (attr, ordinal) <- tuple.zipWithIndex + if attr == a + } { + logger.debug(s"Binding $attr to $inputTuple.$ordinal given input $inputAsString") + return BoundReference(inputTuple, ordinal, a) } + + logger.debug(s"No binding found for $a given input $inputAsString") + a } } } -} \ No newline at end of file +} diff --git a/src/main/scala/catalyst/plans/logical/basicOperators.scala 
b/src/main/scala/catalyst/plans/logical/basicOperators.scala index 81ae69c2bd14a..9756672c85654 100644 --- a/src/main/scala/catalyst/plans/logical/basicOperators.scala +++ b/src/main/scala/catalyst/plans/logical/basicOperators.scala @@ -41,6 +41,10 @@ case class InsertIntoTable(table: BaseRelation, partition: Map[String, Option[St def children = table :: child :: Nil def references = Set.empty def output = child.output + + override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { + case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType + } } case class InsertIntoCreatedTable(tableName: String, child: LogicalPlan) extends UnaryNode { diff --git a/src/test/scala/catalyst/execution/HiveSerDeSuite.scala b/src/test/scala/catalyst/execution/HiveSerDeSuite.scala new file mode 100644 index 0000000000000..2fea958cc7cac --- /dev/null +++ b/src/test/scala/catalyst/execution/HiveSerDeSuite.scala @@ -0,0 +1,14 @@ +package catalyst.execution + +/** + * A set of tests that validates support for Hive SerDe. + */ +class HiveSerDeSuite extends HiveComparisonTest { + createQueryTest( + "Read and write with LazySimpleSerDe (tab separated)", + "SELECT * from serdeins") + + createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") + + createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") +} diff --git a/src/test/scala/catalyst/execution/PartitionPruningSuite.scala b/src/test/scala/catalyst/execution/PartitionPruningSuite.scala new file mode 100644 index 0000000000000..f13317ddc9028 --- /dev/null +++ b/src/test/scala/catalyst/execution/PartitionPruningSuite.scala @@ -0,0 +1,51 @@ +package catalyst.execution + +import scala.collection.JavaConversions._ + +import TestShark._ + +class PartitionPruningSuite extends HiveComparisonTest { + createPruningTest("Pruning with predicate on STRING partition key", + "SELECT * FROM srcpart1 WHERE ds = '2008-04-08'", + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-08", "12"))) + + createPruningTest("Pruning with predicate on INT partition key", + "SELECT * FROM srcpart1 WHERE hr < 12", + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-09", "11"))) + + createPruningTest("Select only 1 partition", + "SELECT * FROM srcpart1 WHERE ds = '2008-04-08' AND hr < 12", + Seq( + Seq("2008-04-08", "11"))) + + createPruningTest("All partitions pruned", + "SELECT * FROM srcpart1 WHERE ds = '2014-01-27' AND hr = 11", + Seq.empty) + + createPruningTest("Pruning with both column key and partition key", + "SELECT * FROM srcpart1 WHERE value IS NOT NULL AND hr < 12", + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-09", "11"))) + + def createPruningTest(testCaseName: String, sql: String, expectedValues: Seq[Seq[String]]) = { + test(testCaseName) { + val plan = sql.q.executedPlan + val prunedPartitions = plan.collect { + case p @ HiveTableScan(_, relation, _) => + p.prunePartitions(relation.hiveQlPartitions) + }.head + val values = prunedPartitions.map(_.getValues) + + assert(prunedPartitions.size === expectedValues.size) + + for ((actual, expected) <- values.zip(expectedValues)) { + assert(actual sameElements expected) + } + } + } +}
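Illustrative notes on the main techniques in this patch follow; the sketches use invented stand-in types and names and are not part of the diff itself.

PreInsertionCasts (Analyzer.scala) aligns the child plan's output types with the target table's schema before an INSERT by wrapping only the mismatched columns in casts. A minimal standalone sketch of that zip-and-cast step, using toy Attribute/Ref/Cast types rather than the real Catalyst classes:

object PreInsertionCastSketch {
  sealed trait DataType
  case object IntType extends DataType
  case object StringType extends DataType

  case class Attribute(name: String, dataType: DataType)

  sealed trait Expr { def name: String; def dataType: DataType }
  case class Ref(attr: Attribute) extends Expr {
    def name = attr.name
    def dataType = attr.dataType
  }
  case class Cast(child: Expr, dataType: DataType) extends Expr {
    def name = child.name
  }

  // Pair each child output column with the corresponding table column and cast
  // only where the types differ. (The real rule additionally returns the plan
  // unchanged when every type already matches.)
  def castChildOutput(childOutput: Seq[Attribute], tableOutput: Seq[Attribute]): Seq[Expr] =
    childOutput.zip(tableOutput).map {
      case (c, t) if c.dataType != t.dataType => Cast(Ref(c), t.dataType)
      case (c, _)                             => Ref(c)
    }

  def main(args: Array[String]): Unit = {
    val child = Seq(Attribute("key", StringType), Attribute("value", StringType))
    val table = Seq(Attribute("key", IntType), Attribute("value", StringType))
    // Only "key" is wrapped in a Cast; "value" already matches the table schema.
    castChildOutput(child, table).foreach(println)
  }
}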
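PartitionPrunings (planningStrategies.scala) does not prune anything itself: it splits the predicate list into predicates that reference only partition keys (passed down to HiveTableScan, which evaluates them against partition metadata) and everything else (kept in a Filter above the scan). A hedged sketch of that split, with a toy Predicate type that just records which column names it references:

object PartitionPruningSketch {
  case class Predicate(description: String, references: Set[String])

  // Predicates whose references are a subset of the partition keys can be
  // evaluated from partition values alone; the rest need actual row data.
  def split(predicates: Seq[Predicate], partitionKeys: Set[String]): (Seq[Predicate], Seq[Predicate]) =
    predicates.partition(_.references.subsetOf(partitionKeys))

  def main(args: Array[String]): Unit = {
    val partitionKeys = Set("ds", "hr")
    val predicates = Seq(
      Predicate("ds = '2008-04-08'", Set("ds")),
      Predicate("hr < 12", Set("hr")),
      Predicate("value IS NOT NULL", Set("value")))

    val (pruning, other) = split(predicates, partitionKeys)
    println(s"pushed into the scan: ${pruning.map(_.description).mkString(", ")}")
    println(s"left in a Filter:     ${other.map(_.description).mkString(", ")}")
  }
}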
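BindReferences.bindReference (BoundAttribute.scala) resolves an attribute to a (input tuple, ordinal) position by scanning each child's output in order, falling back to leaving the attribute unbound when no match is found. A reduced sketch of that lookup over plain attribute names (not the actual Catalyst expression types):

object BindReferenceSketch {
  // Returns Some((inputTuple, ordinal)) for the first child output containing `name`,
  // or None when the attribute cannot be bound (mirroring the debug-and-return-a path).
  def bind(name: String, input: Seq[Seq[String]]): Option[(Int, Int)] = {
    val candidates = for {
      (tuple, inputTuple) <- input.zipWithIndex
      (attr, ordinal)     <- tuple.zipWithIndex
      if attr == name
    } yield (inputTuple, ordinal)
    candidates.headOption
  }

  def main(args: Array[String]): Unit = {
    val childOutputs = Seq(Seq("key", "value"), Seq("ds", "hr"))
    println(bind("hr", childOutputs))      // Some((1, 1))
    println(bind("missing", childOutputs)) // None: left unbound
  }
}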
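loadTestTable (TestShark.scala) now marks a table as loaded before running its creation commands, so two test tables whose commands load each other no longer recurse forever. A small self-contained sketch of why the ordering matters, using hypothetical tables "a" and "b" that depend on each other:

object TestTableLoadingSketch {
  private val loaded = scala.collection.mutable.Set.empty[String]

  // Each "table" is a list of thunks; creating "a" first loads "b" and vice versa.
  val commands: Map[String, Seq[() => Unit]] = Map(
    "a" -> Seq(() => load("b"), () => println("creating a")),
    "b" -> Seq(() => load("a"), () => println("creating b")))

  def load(name: String): Unit =
    if (!loaded(name)) {
      loaded += name // mark as loaded *before* running the commands, so the cycle terminates
      commands(name).foreach(_())
    }

  def main(args: Array[String]): Unit = load("a")
}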