diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d367eff845b88..35c3c82551f36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -474,46 +474,60 @@ class Analyzer(
   object ImplicitGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case Project(Seq(Alias(g: Generator, name)), child) =>
-        Generate(g, join = false, outer = false, child, qualifier = None, name :: Nil, Nil)
+        Generate(g, join = false, outer = false,
+          qualifier = None, UnresolvedAttribute(name) :: Nil, child)
       case Project(Seq(MultiAlias(g: Generator, names)), child) =>
-        Generate(g, join = false, outer = false, child, qualifier = None, names, Nil)
+        Generate(g, join = false, outer = false,
+          qualifier = None, names.map(UnresolvedAttribute(_)), child)
     }
   }
 
+  /**
+   * Resolves [[Generate]]: if output names are specified, we take them; otherwise
+   * we provide default names, following the same rule as Hive.
+   */
   object ResolveGenerate extends Rule[LogicalPlan] {
     // Construct the output attributes for the generator,
    // The output attribute names can be either specified or
    // auto generated.
    private def makeGeneratorOutput(
        generator: Generator,
-        attributeNames: Seq[String],
-        qualifier: Option[String]): Array[Attribute] = {
+        generatorOutput: Seq[Attribute]): Seq[Attribute] = {
      val elementTypes = generator.elementTypes
 
-      val raw = if (attributeNames.size == elementTypes.size) {
-        attributeNames.zip(elementTypes).map {
-          case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
+      if (generatorOutput.size == elementTypes.size) {
+        generatorOutput.zip(elementTypes).map {
+          case (a, (t, nullable)) if !a.resolved =>
+            AttributeReference(a.name, t, nullable)()
+          case (a, _) => a
        }
-      } else {
+      } else if (generatorOutput.length == 0) {
        elementTypes.zipWithIndex.map {
          // keep the default column names as Hive does _c0, _c1, _cN
          case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
        }
+      } else {
+        throw new AnalysisException(
+          s"""
+             |The number of aliases supplied in the AS clause does not match
+             |the number of columns output by the UDTF: expected
+             |${elementTypes.size} aliases but got ${generatorOutput.size}.
+           """.stripMargin)
      }
-
-      qualifier.map(q => raw.map(_.withQualifiers(q :: Nil))).getOrElse(raw).toArray[Attribute]
    }
 
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case p: Generate if !p.child.resolved || !p.generator.resolved => p
      case p: Generate if p.resolved == false =>
        // if the generator output names are not specified, we will use the default ones.
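+        // Rebuild the node via makeGeneratorOutput so the output attributes are fully resolved.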
-        val gOutput = makeGeneratorOutput(p.generator, p.attributeNames, p.qualifier)
        Generate(
-          p.generator, p.join, p.outer, p.child, p.qualifier, gOutput.map(_.name), gOutput)
+          p.generator,
+          join = p.join,
+          outer = p.outer,
+          p.qualifier,
+          makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
    }
  }
-
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index fc87b34bb1d3f..4e5c64bb63c9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -289,8 +289,8 @@ package object dsl {
        generator: Generator,
        join: Boolean = false,
        outer: Boolean = false,
-        alias: Option[String] = None): Generate =
-      Generate(generator, join, outer, logicalPlan, alias)
+        alias: Option[String] = None): LogicalPlan =
+      Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
 
    def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
      InsertIntoTable(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 3257fbc67e368..ae8624fb04e19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -42,12 +42,15 @@ abstract class Generator extends Expression {
 
  override type EvaluatedType = TraversableOnce[Row]
 
-  override def dataType: DataType = ???
+  // TODO: ideally we should return ArrayType(StructType); however, we don't
+  // keep the output field names in the Generator.
+  override def dataType: DataType = throw new UnsupportedOperationException
 
  override def nullable: Boolean = false
 
  /**
   * The output element data types in structure of Seq[(DataType, Nullable)]
+   * TODO: we probably need to add more information, e.g. metadata.
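+   * e.g. Seq((IntegerType, true)) for a generator that outputs a single nullable int column.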
   */
  def elementTypes: Seq[(DataType, Boolean)]
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f5c71ee1da21c..55dfcbf7cd143 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -486,7 +486,7 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
    if (pushDown.nonEmpty) {
      val pushDownPredicate = pushDown.reduce(And)
      val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
-        Filter(pushDownPredicate, g.child), g.qualifier, g.attributeNames, g.gOutput)
+        g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
      stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
    } else {
      filter
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 6acb83c5e32de..ca0a700869a0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -45,27 +45,23 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
 *            it.
 * @param outer when true, each input row will be output at least once, even if the output of the
 *              given `generator` is empty. `outer` has no effect when `join` is false.
- * @param child Children logical plan node
 * @param qualifier Qualifier for the attributes of generator(UDTF)
- * @param attributeNames the column names for the generator(UDTF), will be _c0, _c1 .. _cN if
- *                       leave as default (empty)
- * @param gOutput The output of Generator.
+ * @param generatorOutput The output schema of the Generator.
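+ *                        These may start as unresolved attributes; ResolveGenerate
+ *                        fills in the element types and default (_c0, _c1, ...) names.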
+ * @param child the child logical plan node
 */
 case class Generate(
    generator: Generator,
    join: Boolean,
    outer: Boolean,
-    child: LogicalPlan,
-    qualifier: Option[String] = None,
-    attributeNames: Seq[String] = Nil,
-    gOutput: Seq[Attribute] = Nil)
+    qualifier: Option[String],
+    generatorOutput: Seq[Attribute],
+    child: LogicalPlan)
  extends UnaryNode {
 
  override lazy val resolved: Boolean = {
    generator.resolved &&
      childrenResolved &&
-      attributeNames.length > 0 &&
-      gOutput.map(_.name) == attributeNames
+      generatorOutput.forall(_.resolved)
  }
 
-  // we don't want the gOutput to be taken as part of the expressions
+  // we don't want the generatorOutput to be taken as part of the expressions
@@ -73,7 +69,11 @@ case class Generate(
  override def expressions: Seq[Expression] = generator :: Nil
 
  def output: Seq[Attribute] = {
-    if (join) child.output ++ gOutput else gOutput
+    val withoutQualifier = if (join) child.output ++ generatorOutput else generatorOutput
+
+    qualifier.map(q =>
+      withoutQualifier.map(_.withQualifiers(q :: Nil))
+    ).getOrElse(withoutQualifier)
  }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 858b126f8df32..8aadc0fafa971 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -719,7 +719,8 @@ class DataFrame private[sql](
      f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
 
    val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
-    Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
+    Generate(generator, join = true, outer = false,
+      qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
  }
 
  /**
@@ -745,7 +746,8 @@ class DataFrame private[sql](
    }
    val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
 
-    Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
+    Generate(generator, join = true, outer = false,
+      qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
  }
 
  /////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 98c0d4c9b5cf2..5201e20a10565 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
 * programming with one important additional feature, which allows the input rows to be joined with
 * their output.
+ * @param generator the generator expression
 * @param join when true, each output row is implicitly joined with the input tuple that produced
 *              it.
 * @param outer when true, each input row will be output at least once, even if the output of the
@@ -39,8 +40,8 @@ case class Generate(
    generator: Generator,
    join: Boolean,
    outer: Boolean,
-    child: SparkPlan,
-    output: Seq[Attribute])
+    output: Seq[Attribute],
+    child: SparkPlan)
  extends UnaryNode {
 
  val boundGenerator = BindReferences.bindReference(generator, child.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index fee3008d18050..29841bf3e0c72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -304,9 +304,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        execution.Except(planLater(left), planLater(right)) :: Nil
      case logical.Intersect(left, right) =>
        execution.Intersect(planLater(left), planLater(right)) :: Nil
-      case g @ logical.Generate(generator, join, outer, child, _, _, _) =>
+      case g @ logical.Generate(generator, join, outer, _, _, child) =>
        execution.Generate(
-          generator, join = join, outer = outer, planLater(child), g.output) :: Nil
+          generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
      case logical.OneRowRelation =>
        execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
      case logical.Repartition(expressions, child) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index e5ada7f5178e9..57e835df7bd46 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -730,10 +730,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
              generator,
              join = true,
              outer = false,
-              withWhere,
              Some(alias.toLowerCase),
-              attributes,
-              Nil)
+              attributes.map(UnresolvedAttribute(_)),
+              withWhere)
          }.getOrElse(withWhere)
 
        // The projection of the query can either be a normal projection, an aggregation
@@ -841,10 +840,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
            generator,
            join = true,
            outer = isOuter.nonEmpty,
-            nodeToRelation(relationClause),
            Some(alias.toLowerCase),
-            attributes,
-            Nil)
+            attributes.map(UnresolvedAttribute(_)),
+            nodeToRelation(relationClause))
 
    /* All relations, possibly with aliases or sampling clauses. */
    case Token("TOK_TABREF", clauses) =>
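Taken together, the `join` and `outer` flags documented above make Generate behave like a
flatMap with an optional join back to the input row. A rough sketch of the intended semantics
in plain Scala; `gen`, `concat`, and `nullRow` are hypothetical stand-ins, not names from
this patch:

// Semantic sketch only: `gen` produces zero or more output rows per input row.
def generateSemantics[Row](
    input: Seq[Row],
    gen: Row => Seq[Row],
    join: Boolean,
    outer: Boolean,
    nullRow: Row,                     // all-null stand-in for an empty generator output
    concat: (Row, Row) => Row): Seq[Row] =
  input.flatMap { row =>
    val out = gen(row)
    if (!join) {
      out                             // join = false: emit only the generated rows
    } else if (out.isEmpty && outer) {
      Seq(concat(row, nullRow))       // outer = true: keep inputs with empty output
    } else {
      out.map(concat(row, _))         // join = true: pair each output with its input
    }
  }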