Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SQL][DataFrame] Fix column computability bug. #4519

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
val sqlContext = new SQLContext(sc)
import sqlContext.implicits.createDataFrame
import sqlContext.implicits._
val metadata = (thisClassName, thisFormatVersion, model.rank)
val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
Expand Down
35 changes: 26 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,27 +66,44 @@ trait Column extends DataFrame {
*/
def isComputable: Boolean

/** Strips the single top-level Project node, exposing the plan beneath it. */
private def stripProject(p: LogicalPlan): LogicalPlan = {
  p match {
    case Project(_, child) =>
      child
    case other =>
      // Any other node shape means the caller handed us a plan without a top Project.
      sys.error("Unexpected logical plan (expected Project): " + other)
  }
}

/**
 * Builds a computable Column by projecting `expr` on top of `baseCol`'s plan.
 * NOTE(review): this span is a GitHub diff rendering with removed and added lines
 * interleaved (the +/- markers did not survive the scrape); the comments below mark
 * which lines belong to which version. As written, the span is not compilable code.
 */
private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
// Removed line (pre-change): built the Project inline, directly over baseCol.plan.
val plan = Project(Seq(expr match {
// Added line (post-change): first normalize `expr` into a named expression.
val namedExpr = expr match {
case named: NamedExpression => named
// Unnamed expressions get wrapped in an Alias so the projection has a column name.
case unnamed: Expression => Alias(unnamed, "col")()
// Removed line (pre-change): closing of the old inline Project construction.
}), baseCol.plan)
}
// Added line (post-change): project over the stripped base plan instead of the
// plan itself — presumably to avoid stacking Projects; TODO confirm against full file.
val plan = Project(Seq(namedExpr), stripProject(baseCol.plan))
Column(baseCol.sqlContext, plan, expr)
}

/**
* Construct a new column based on the expression and the other column value.
*
* There are two cases that can happen here:
* If otherValue is a constant, it is first turned into a Column.
* If otherValue is a Column, then:
* - If this column and otherValue are both computable and come from the same logical plan,
* then we can construct a ComputableColumn by applying a Project on top of the base plan.
* - If this column is not computable, but otherValue is computable, then we can construct
* a ComputableColumn based on otherValue's base plan.
* - If this column is computable, but otherValue is not, then we can construct a
* ComputableColumn based on this column's base plan.
* - If neither columns are computable, then we create an IncomputableColumn.
*/
private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
// Removes all the top level projection and subquery so we can get to the underlying plan.
@tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
case Project(_, child) => stripProject(child)
case Subquery(_, child) => stripProject(child)
case _ => p
}

// lit(otherValue) returns a Column always.
(this, lit(otherValue)) match {
case (left: ComputableColumn, right: ComputableColumn) =>
if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
computableCol(right, newExpr(right))
} else {
// We don't want to throw an exception here because "df1("a") === df2("b")" can be
// a valid expression for join conditions, even though standalone they are not valid.
Column(newExpr(right))
}
case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
// NOTE(review): diff scrape — the two `implicit def` headers below are the
// pre-change and post-change versions of the same definition.
// Removed line (pre-change): the implicit previously shared the name `createDataFrame`.
implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
// Added line (post-change): renamed, presumably to avoid ambiguity with the
// non-implicit createDataFrame overloads — TODO confirm against the full file.
implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
// Delegates to the SQLContext's createDataFrame.
self.createDataFrame(rdd)
}

/**
* Creates a DataFrame from a local Seq of Product.
* NOTE(review): diff scrape — the two `implicit def` headers below are the
* pre-change and post-change versions of the same definition.
*/
// Removed line (pre-change): the implicit previously shared the name `createDataFrame`.
implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
// Added line (post-change): renamed, presumably to avoid ambiguity with the
// non-implicit createDataFrame overloads — TODO confirm against the full file.
implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
// Delegates to the SQLContext's createDataFrame.
self.createDataFrame(data)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}


Expand All @@ -44,10 +45,10 @@ class ColumnExpressionSuite extends QueryTest {
shouldBeComputable(-testData2("a"))
shouldBeComputable(!testData2("a"))

shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
shouldBeComputable(
shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
shouldNotBeComputable(
testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
shouldBeComputable(
shouldNotBeComputable(
testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))

// Literals and unresolved columns should not be computable.
Expand All @@ -66,6 +67,12 @@ class ColumnExpressionSuite extends QueryTest {
shouldNotBeComputable(sum(testData2("a")))
}

test("collect on column produced by a binary operator") {
  // A binary op over columns of the same DataFrame should yield a collectable column.
  val frame = Seq((1, 2, 3)).toDataFrame("a", "b", "c")
  val expected = Seq(Row(3))
  checkAnswer(frame("a") + frame("b"), expected)
  // Aliasing one operand must not break computability.
  checkAnswer(frame("a") + frame("b").as("c"), expected)
}

test("star") {
  // Selecting every column via "*" must round-trip the full dataset.
  val expected = testData.collect().toSeq
  checkAnswer(testData.select($"*"), expected)
}
Expand Down