Skip to content

Commit

Permalink
Update the code as feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Apr 17, 2015
1 parent ca5e7f4 commit d2e8b43
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -474,46 +474,60 @@ class Analyzer(
object ImplicitGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Project(Seq(Alias(g: Generator, name)), child) =>
Generate(g, join = false, outer = false, child, qualifier = None, name :: Nil, Nil)
Generate(g, join = false, outer = false,
qualifier = None, UnresolvedAttribute(name) :: Nil, child)
case Project(Seq(MultiAlias(g: Generator, names)), child) =>
Generate(g, join = false, outer = false, child, qualifier = None, names, Nil)
Generate(g, join = false, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), child)
}
}

/**
* Resolve the Generate, if the output names specified, we will take them, otherwise
* we will try to provide the default names, which follow the same rule with Hive.
*/
object ResolveGenerate extends Rule[LogicalPlan] {
// Construct the output attributes for the generator,
// The output attribute names can be either specified or
// auto generated.
private def makeGeneratorOutput(
generator: Generator,
attributeNames: Seq[String],
qualifier: Option[String]): Array[Attribute] = {
generatorOutput: Seq[Attribute]): Seq[Attribute] = {
val elementTypes = generator.elementTypes

val raw = if (attributeNames.size == elementTypes.size) {
attributeNames.zip(elementTypes).map {
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
if (generatorOutput.size == elementTypes.size) {
generatorOutput.zip(elementTypes).map {
case (a, (t, nullable)) if !a.resolved =>
AttributeReference(a.name, t, nullable)()
case (a, _) => a
}
} else {
} else if (generatorOutput.length == 0) {
elementTypes.zipWithIndex.map {
// keep the default column names as Hive does _c0, _c1, _cN
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
}
} else {
throw new AnalysisException(
s"""
|The number of aliases supplied in the AS clause does not match
|the number of columns output by the UDTF expected
|${elementTypes.size} aliases but got ${generatorOutput.size}
""".stripMargin)
}

qualifier.map(q => raw.map(_.withQualifiers(q :: Nil))).getOrElse(raw).toArray[Attribute]
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case p: Generate if p.resolved == false =>
// if the generator output names are not specified, we will use the default ones.
val gOutput = makeGeneratorOutput(p.generator, p.attributeNames, p.qualifier)
Generate(
p.generator, p.join, p.outer, p.child, p.qualifier, gOutput.map(_.name), gOutput)
p.generator,
join = p.join,
outer = p.outer,
p.qualifier,
makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
}
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ package object dsl {
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None): Generate =
Generate(generator, join, outer, logicalPlan, alias)
alias: Option[String] = None): LogicalPlan =
Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)

def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@ abstract class Generator extends Expression {

override type EvaluatedType = TraversableOnce[Row]

override def dataType: DataType = ???
// TODO ideally we should return the type of ArrayType(StructType),
// however, we don't keep the output field names in the Generator.
override def dataType: DataType = throw new UnsupportedOperationException

override def nullable: Boolean = false

/**
* The output element data types in structure of Seq[(DataType, Nullable)]
* TODO we probably need to add more information like metadata etc.
*/
def elementTypes: Seq[(DataType, Boolean)]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
Filter(pushDownPredicate, g.child), g.qualifier, g.attributeNames, g.gOutput)
g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
} else {
filter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,35 +45,35 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
* @param child Children logical plan node
* @param qualifier Qualifier for the attributes of generator(UDTF)
* @param attributeNames the column names for the generator(UDTF), will be _c0, _c1 .. _cN if
* leave as default (empty)
* @param gOutput The output of Generator.
* @param generatorOutput The output schema of the Generator.
* @param child Children logical plan node
*/
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
child: LogicalPlan,
qualifier: Option[String] = None,
attributeNames: Seq[String] = Nil,
gOutput: Seq[Attribute] = Nil)
qualifier: Option[String],
generatorOutput: Seq[Attribute],
child: LogicalPlan)
extends UnaryNode {

override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
attributeNames.length > 0 &&
gOutput.map(_.name) == attributeNames
!generatorOutput.exists(!_.resolved)
}

// we don't want the gOutput to be taken as part of the expressions
// as that will cause exceptions like unresolved attributes etc.
override def expressions: Seq[Expression] = generator :: Nil

def output: Seq[Attribute] = {
if (join) child.output ++ gOutput else gOutput
val withoutQualifier = if (join) child.output ++ generatorOutput else generatorOutput

qualifier.map(q =>
withoutQualifier.map(_.withQualifiers(q :: Nil))
).getOrElse(withoutQualifier)
}
}

Expand Down
8 changes: 5 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -719,7 +719,8 @@ class DataFrame private[sql](
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))

Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
Generate(generator, join = true, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}

/**
Expand All @@ -745,7 +746,8 @@ class DataFrame private[sql](
}
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)

Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
Generate(generator, join = true, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}

/////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
Expand All @@ -39,8 +40,8 @@ case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
child: SparkPlan,
output: Seq[Attribute])
output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {

val boundGenerator = BindReferences.bindReference(generator, child.output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Except(planLater(left), planLater(right)) :: Nil
case logical.Intersect(left, right) =>
execution.Intersect(planLater(left), planLater(right)) :: Nil
case g @ logical.Generate(generator, join, outer, child, _, _, _) =>
case g @ logical.Generate(generator, join, outer, _, _, child) =>
execution.Generate(
generator, join = join, outer = outer, planLater(child), g.output) :: Nil
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
Expand Down
10 changes: 4 additions & 6 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -730,10 +730,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
generator,
join = true,
outer = false,
withWhere,
Some(alias.toLowerCase),
attributes,
Nil)
attributes.map(UnresolvedAttribute(_)),
withWhere)
}.getOrElse(withWhere)

// The projection of the query can either be a normal projection, an aggregation
Expand Down Expand Up @@ -841,10 +840,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
generator,
join = true,
outer = isOuter.nonEmpty,
nodeToRelation(relationClause),
Some(alias.toLowerCase),
attributes,
Nil)
attributes.map(UnresolvedAttribute(_)),
nodeToRelation(relationClause))

/* All relations, possibly with aliases or sampling clauses. */
case Token("TOK_TABREF", clauses) =>
Expand Down

0 comments on commit d2e8b43

Please sign in to comment.