[SPARK-13553][SPARK-13554][SQL] Migrates basic inspection and typed relational operations from DataFrame to Dataset #11431

Closed
wants to merge 4 commits
274 changes: 261 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,17 +19,17 @@ package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAlias}
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
import org.apache.spark.sql.execution.{ExplainCommand, Queryable, QueryExecution}
import org.apache.spark.sql.types.StructType
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -91,6 +91,10 @@ class Dataset[T] private[sql](
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
this(sqlContext, new QueryExecution(sqlContext, plan), encoder)

/* ***************************** *
* Basic Inspection Operations *
* ***************************** */

/**
* Returns the schema of the encoded form of the objects in this [[Dataset]].
* @since 1.6.0
@@ -101,19 +105,62 @@
* Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
* @since 1.6.0
*/
override def printSchema(): Unit = toDF().printSchema()
// scalastyle:off println
override def printSchema(): Unit = println(schema.treeString)
// scalastyle:on println

/**
* Prints the plans (logical and physical) to the console for debugging purposes.
* @since 1.6.0
*/
override def explain(extended: Boolean): Unit = toDF().explain(extended)
override def explain(extended: Boolean): Unit = {
val explain = ExplainCommand(queryExecution.logical, extended = extended)
sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
// scalastyle:off println
r => println(r.getString(0))
// scalastyle:on println
}
}

/**
* Prints the physical plan to the console for debugging purposes.
* @since 1.6.0
*/
override def explain(): Unit = toDF().explain()
override def explain(): Unit = explain(extended = false)
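
A minimal usage sketch of the two explain variants above (the ds value is hypothetical, and `import sqlContext.implicits._` is assumed to be in scope for toDS()):

  val ds = Seq(1, 2, 3).toDS()
  ds.explain()                 // prints only the physical plan
  ds.explain(extended = true)  // also prints the logical, analyzed and optimized plans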

/**
* Returns all column names and their data types as an array.
* @since 2.0.0
Contributor: If the API is moved from DataFrame, should we also copy the @since? cc @rxin

Contributor: I'd put 2.0 since it didn't exist on Dataset before.

Contributor: Since DataFrame will be an alias of Dataset, what will the doc for DataFrame look like?

Contributor: If we copy the versions, it's also weird to see a method of Dataset (1.3) introduced before Dataset itself was introduced (1.6).

Contributor: Yea, I'd just have it as 2.0.
*/
def dtypes: Array[(String, String)] = schema.fields.map { field =>
(field.name, field.dataType.toString)
}

/**
* Returns all column names as an array.
* @since 2.0.0
*/
def columns: Array[String] = schema.fields.map(_.name)
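
A quick sketch of dtypes and columns (the Person case class and people Dataset are hypothetical; `import sqlContext.implicits._` is assumed):

  case class Person(name: String, age: Long)
  val people = Seq(Person("Ann", 21), Person("Bob", 12)).toDS()
  people.columns  // Array("name", "age")
  people.dtypes   // Array(("name", "StringType"), ("age", "LongType"))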

/**
* Selects a column based on the column name and returns it as a [[Column]].
* Note that the column name can also reference a nested column like `a.b`.
* @since 2.0.0
*/
def apply(colName: String): Column = col(colName)

/**
* Selects a column based on the column name and returns it as a [[Column]].
* Note that the column name can also reference a nested column like `a.b`.
* @since 2.0.0
*/
def col(colName: String): Column = colName match {
case "*" =>
Column(ResolvedStar(queryExecution.analyzed.output))
case _ =>
val expr = resolve(colName)
Column(expr)
}
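
A small sketch of col and its apply alias, reusing the hypothetical people Dataset from the sketch above:

  val age: Column = people.col("age")  // resolve a single column by name
  val sameAge = people("age")          // apply(colName) simply delegates to col(colName)
  val everything = people.col("*")     // "*" resolves to all output columns of the analyzed plan
  people.filter(people("age") > 15)    // resolved Columns feed directly into the typed operators below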

/* ************* *
* Conversions *
@@ -137,13 +184,6 @@
new Dataset(sqlContext, queryExecution, encoderFor[U])
}

/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
* the same name after two Datasets have been joined.
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _))

/**
* Converts this strongly typed collection of data to a generic DataFrame. In contrast to the
* strongly typed objects that Dataset operations work on, a DataFrame returns generic [[Row]]
@@ -273,6 +313,32 @@ class Dataset[T] private[sql](
Repartition(numPartitions, shuffle = true, _)
}

/**
* Returns a new [[Dataset]] partitioned by the given partitioning expressions into
* `numPartitions`. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withPlan {
RepartitionByExpression(partitionExprs.map(_.expr), _, Some(numPartitions))
}

/**
* Returns a new [[Dataset]] partitioned by the given partitioning expressions preserving
* the existing number of partitions. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def repartition(partitionExprs: Column*): Dataset[T] = withPlan {
RepartitionByExpression(partitionExprs.map(_.expr), _, numPartitions = None)
}
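
A brief sketch of the two expression-based repartition overloads added here (hypothetical people Dataset again):

  // Hash-partition by the name column into 8 partitions, like DISTRIBUTE BY in HiveQL:
  val byName = people.repartition(8, people("name"))
  // Same partitioning expression, but keeping the existing number of partitions:
  val byNameSameCount = people.repartition(people("name"))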

/**
* Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
@@ -475,6 +541,13 @@ class Dataset[T] private[sql](
* Typed Relational *
* ****************** */

/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
* the same name after two Datasets have been joined.
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _))

/**
* Returns a new [[DataFrame]] by selecting a set of column-based expressions.
* {{{
@@ -574,6 +647,162 @@ class Dataset[T] private[sql](
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
* Filters rows using the given condition.
* {{{
* // The following are equivalent:
* peopleDs.filter($"age" > 15)
* peopleDs.where($"age" > 15)
* }}}
* @since 2.0.0
*/
def filter(condition: Column): Dataset[T] = withPlan(Filter(condition.expr, _))

/**
* Filters rows using the given SQL expression.
* {{{
* peopleDs.filter("age > 15")
* }}}
* @since 2.0.0
*/
def filter(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}

/**
* Filters rows using the given condition. This is an alias for `filter`.
* {{{
* // The following are equivalent:
* peopleDs.filter($"age" > 15)
* peopleDs.where($"age" > 15)
* }}}
* @since 2.0.0
*/
def where(condition: Column): Dataset[T] = filter(condition)

/**
* Filters rows using the given SQL expression.
* {{{
* peopleDs.where("age > 15")
* }}}
* @since 2.0.0
*/
def where(conditionExpr: String): Dataset[T] = filter(conditionExpr)

/**
* Returns a new [[Dataset]] by taking the first `n` rows. The difference between this function
* and `head` is that `head` returns an array while `limit` returns a new [[Dataset]].
* @since 2.0.0
*/
def limit(n: Int): Dataset[T] = withPlan(Limit(Literal(n), _))
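
A short sketch contrasting limit with the eager take (hypothetical people Dataset; take was already available on Dataset and is used here in place of head):

  val firstTen: Dataset[Person] = people.limit(10)    // lazy: a new Dataset usable in further transformations
  val firstTenLocal: Array[Person] = people.take(10)  // eager: materializes the rows on the driver as an Array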

/**
* Returns a new [[Dataset]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
}

/**
* Returns a new [[Dataset]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
sortInternal(global = false, sortExprs)
}
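
A sketch of per-partition versus global sorting (hypothetical people Dataset; the $ syntax comes from sqlContext.implicits._):

  // Sort rows within each partition only, with no shuffle, like SORT BY in HiveQL:
  val locallySorted = people.sortWithinPartitions($"age".desc)
  // A total ordering across partitions uses sort/orderBy instead:
  val globallySorted = people.sort($"age".desc)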

/**
* Returns a new [[Dataset]] sorted by the specified column, all in ascending order.
* {{{
* // The following 3 are equivalent
* ds.sort("sortcol")
* ds.sort($"sortcol")
* ds.sort($"sortcol".asc)
* }}}
* @since 2.0.0
*/
@scala.annotation.varargs
def sort(sortCol: String, sortCols: String*): Dataset[T] = {
sort((sortCol +: sortCols).map(apply) : _*)
}

/**
* Returns a new [[Dataset]] sorted by the given expressions. For example:
* {{{
* ds.sort($"col1", $"col2".desc)
* }}}
* @since 2.0.0
*/
@scala.annotation.varargs
def sort(sortExprs: Column*): Dataset[T] = {
sortInternal(global = true, sortExprs)
}

/**
* Returns a new [[Dataset]] sorted by the given expressions.
* This is an alias of the `sort` function.
* @since 2.0.0
*/
@scala.annotation.varargs
def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*)

/**
* Returns a new [[Dataset]] sorted by the given expressions.
* This is an alias of the `sort` function.
* @since 2.0.0
*/
@scala.annotation.varargs
def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*)

/**
* Randomly splits this [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
* @since 2.0.0
*/
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
// It is possible that the underlying Dataset doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
// ordering deterministic.
val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan)
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
Contributor: Do we need to pass an encoder into the newly created Datasets here?

Contributor (Author): No, there's an implicit encoder defined in the constructor of Dataset.
}.toArray
}

/**
* Randomly splits this [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @since 2.0.0
*/
def randomSplit(weights: Array[Double]): Array[Dataset[T]] = {
randomSplit(weights, Utils.random.nextLong)
}
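
A sketch of randomSplit, which normalizes the weights before slicing; the 3.0/1.0 weights below (hypothetical, as is the people Dataset) yield roughly a 75%/25% split:

  val Array(train, test) = people.randomSplit(Array(3.0, 1.0), seed = 42L)
  train.count() + test.count()  // == people.count(); the splits are disjoint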

/**
* Randomly splits this [[Dataset]] with the provided weights. Provided for the Python API.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*/
private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = {
randomSplit(weights.toArray, seed)
}

/* **************** *
* Set operations *
* **************** */
@@ -791,4 +1020,23 @@ class Dataset[T] private[sql](
other: Dataset[_])(
f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))

protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
throw new AnalysisException(
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
}
}

private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
withPlan(Sort(sortOrder, global = global, _))
}
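
A small illustration of the two cases sortInternal handles (hypothetical people Dataset): a bare column expression defaults to ascending order, while .asc/.desc already produce a SortOrder that is used as-is.

  people.sort($"age")       // plain expression, wrapped as SortOrder(age, Ascending)
  people.sort($"age".desc)  // expression is already a SortOrder, so it is used directly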
}
@@ -1154,11 +1154,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(original.rdd.partitions.length == 1)
val df = original.repartition(5, $"key")
assert(df.rdd.partitions.length == 5)
checkAnswer(original.select(), df.select())
checkAnswer(original, df)

val df2 = original.repartition(10, $"key")
assert(df2.rdd.partitions.length == 10)
checkAnswer(original.select(), df2.select())
checkAnswer(original, df2)

// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates