Migrates basic inspection and typed relational DF operations to DS
liancheng committed Feb 29, 2016
1 parent 3f59569 commit ee9c432
Showing 4 changed files with 410 additions and 11 deletions.
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Modifier

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
@@ -559,7 +561,11 @@ class Analyzer(
}

resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
case n: NewInstance
if n.outerPointer.isEmpty &&
n.cls.isMemberClass &&
!Modifier.isStatic(n.cls.getModifiers) =>
n.cls.getEnclosingClass
val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
if (outer == null) {
throw new AnalysisException(
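For context, here is a small standalone sketch (not part of the commit; `Container`, `Inner`, and `needsOuterPointer` are made-up names) of the distinction the added `!Modifier.isStatic(n.cls.getModifiers)` guard draws: only a non-static member class carries a reference to an enclosing instance, so only such a class needs an outer pointer before `NewInstance` can construct it.

import java.lang.reflect.Modifier

// Illustration only: a non-static member class needs an enclosing instance,
// a top-level class does not.
class Container {
  class Inner(val x: Int) // compiled as a non-static member class with an outer pointer
}

object OuterPointerCheckDemo {
  // Mirrors the guard used above: member class AND not static.
  def needsOuterPointer(cls: Class[_]): Boolean =
    cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)

  def main(args: Array[String]): Unit = {
    println(needsOuterPointer(classOf[Container#Inner])) // true
    println(needsOuterPointer(classOf[Container]))       // false
  }
}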
210 changes: 202 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAlias}
import org.apache.spark.sql.catalyst.encoders._
@@ -184,13 +184,6 @@ class Dataset[T] private[sql](
new Dataset(sqlContext, queryExecution, encoderFor[U])
}

/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
* the same name after two Datasets have been joined.
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _))

/**
* Converts this strongly typed collection of data to a generic DataFrame. In contrast to the
* strongly typed objects that Dataset operations work on, a DataFrame returns generic [[Row]]
@@ -320,6 +313,32 @@ class Dataset[T] private[sql](
Repartition(numPartitions, shuffle = true, _)
}

/**
* Returns a new [[Dataset]] partitioned by the given partitioning expressions into
* `numPartitions`. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withPlan {
RepartitionByExpression(partitionExprs.map(_.expr), _, Some(numPartitions))
}

/**
* Returns a new [[Dataset]] partitioned by the given partitioning expressions preserving
* the existing number of partitions. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def repartition(partitionExprs: Column*): Dataset[T] = withPlan {
RepartitionByExpression(partitionExprs.map(_.expr), _, numPartitions = None)
}
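
A hedged usage sketch for the two overloads above; the Dataset `ds`, its "key" column, and the `$` interpolator from `sqlContext.implicits._` are assumed for illustration.

// Hash partition by "key" into exactly 8 partitions:
val fixedCount = ds.repartition(8, $"key")
// Hash partition by "key", keeping the current number of partitions:
val sameCount = ds.repartition($"key")
// Rough SQL equivalent of both: SELECT * FROM t DISTRIBUTE BY key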

/**
* Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
@@ -522,6 +541,13 @@ class Dataset[T] private[sql](
* Typed Relational *
* ****************** */

/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
* the same name after two Datasets have been joined.
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _))
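
A brief usage sketch for `as` (the Datasets `left` and `right` and their shared "id" column are assumed): aliasing both sides of a join keeps identically named columns distinguishable in the join condition.

// Without the aliases, "id" would be ambiguous in the condition below.
val joined = left.as("l").joinWith(right.as("r"), $"l.id" === $"r.id")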

/**
* Returns a new [[DataFrame]] by selecting a set of column-based expressions.
* {{{
@@ -621,6 +647,162 @@ class Dataset[T] private[sql](
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
* Filters rows using the given condition.
* {{{
* // The following are equivalent:
* peopleDs.filter($"age" > 15)
* peopleDs.where($"age" > 15)
* }}}
* @since 2.0.0
*/
def filter(condition: Column): Dataset[T] = withPlan(Filter(condition.expr, _))

/**
* Filters rows using the given SQL expression.
* {{{
* peopleDs.filter("age > 15")
* }}}
* @since 2.0.0
*/
def filter(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}

/**
* Filters rows using the given condition. This is an alias for `filter`.
* {{{
* // The following are equivalent:
* peopleDs.filter($"age" > 15)
* peopleDs.where($"age" > 15)
* }}}
* @since 2.0.0
*/
def where(condition: Column): Dataset[T] = filter(condition)

/**
* Filters rows using the given SQL expression.
* {{{
* peopleDs.where("age > 15")
* }}}
* @since 2.0.0
*/
def where(conditionExpr: String): Dataset[T] = filter(conditionExpr)

/**
* Returns a new [[Dataset]] by taking the first `n` rows. The difference between this function
* and `head` is that `head` returns an array while `limit` returns a new [[Dataset]].
* @since 2.0.0
*/
def limit(n: Int): Dataset[T] = withPlan(Limit(Literal(n), _))
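
The predicate and limit operators above compose in the usual way; a short sketch with an assumed `peopleDs`:

// Equivalent: Column-based and SQL-string predicates, each followed by a limit of 10 rows.
peopleDs.filter($"age" > 15).limit(10)
peopleDs.where("age > 15").limit(10)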

/**
* Returns a new [[Dataset]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
}

/**
* Returns a new [[Dataset]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
*
* @since 2.0.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
sortInternal(global = false, sortExprs)
}

/**
* Returns a new [[Dataset]] sorted by the specified column, all in ascending order.
* {{{
* // The following 3 are equivalent
* ds.sort("sortcol")
* ds.sort($"sortcol")
* ds.sort($"sortcol".asc)
* }}}
* @since 2.0.0
*/
@scala.annotation.varargs
def sort(sortCol: String, sortCols: String*): Dataset[T] = {
sort((sortCol +: sortCols).map(apply) : _*)
}

/**
* Returns a new [[Dataset]] sorted by the given expressions. For example:
* {{{
* ds.sort($"col1", $"col2".desc)
* }}}
* @since 2.0.0
*/
@scala.annotation.varargs
def sort(sortExprs: Column*): Dataset[T] = {
sortInternal(global = true, sortExprs)
}

/**
* Returns a new [[Dataset]] sorted by the given expressions.
* This is an alias of the `sort` function.
* @since 2.0.0
*/
@scala.annotation.varargs
def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*)

/**
* Returns a new [[Dataset]] sorted by the given expressions.
* This is an alias of the `sort` function.
* @since 2.0.0
*/
@scala.annotation.varargs
def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*)
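
To keep the two sorting families apart, a hedged sketch with assumed column names: `sort`/`orderBy` produce a single global ordering (SQL ORDER BY), while `sortWithinPartitions` orders rows only inside each partition (SQL SORT BY).

// Global ordering over the whole Dataset (ORDER BY):
ds.sort($"col1", $"col2".desc)
ds.orderBy($"col1", $"col2".desc)   // alias of sort
// Per-partition ordering only, no total order across partitions (SORT BY):
ds.sortWithinPartitions($"col1")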

/**
* Randomly splits this [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
* @since 2.0.0
*/
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
// It is possible that the underlying Dataset doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized, which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
// ordering deterministic.
val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan)
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
}.toArray
}
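
A standalone worked sketch of the weight handling above (plain Scala, no Spark required): the weights are normalized, accumulated into cumulative bounds, and each adjacent pair of bounds becomes one `Sample` range over the deterministically sorted plan.

val weights = Array(1.0, 2.0, 3.0)
val total = weights.sum
val bounds = weights.map(_ / total).scanLeft(0.0)(_ + _)
// bounds: Array(0.0, 0.1666..., 0.5, 1.0)
val ranges = bounds.sliding(2).map { case Array(lo, hi) => (lo, hi) }.toList
// ranges: List((0.0, 0.1666...), (0.1666..., 0.5), (0.5, 1.0))
// Each (lo, hi) pair corresponds to one Sample(lo, hi, withReplacement = false, seed, sorted) above.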

/**
* Randomly splits this [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @since 2.0.0
*/
def randomSplit(weights: Array[Double]): Array[Dataset[T]] = {
randomSplit(weights, Utils.random.nextLong)
}

/**
* Randomly splits this [[Dataset]] with the provided weights. Provided for the Python API.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*/
private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = {
randomSplit(weights.toArray, seed)
}

/* **************** *
* Set operations *
* **************** */
@@ -845,4 +1027,16 @@ class Dataset[T] private[sql](
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
}
}

private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
withPlan(Sort(sortOrder, global = global, _))
}
}
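
For reference, a sketch of what `sortInternal` sees in the two common call shapes (column names assumed): a bare column arrives as a plain expression and is wrapped in an ascending `SortOrder`, whereas `.asc`/`.desc` columns already carry a `SortOrder` and pass through unchanged.

ds.sort($"a", $"b".desc)
// $"a"       -> SortOrder(a, Ascending)  (wrapped by the second case above)
// $"b".desc  -> SortOrder(b, Descending) (matched by the first case and kept as-is)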
@@ -1154,11 +1154,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(original.rdd.partitions.length == 1)
val df = original.repartition(5, $"key")
assert(df.rdd.partitions.length == 5)
checkAnswer(original.select(), df.select())
checkAnswer(original, df)

val df2 = original.repartition(10, $"key")
assert(df2.rdd.partitions.length == 10)
checkAnswer(original.select(), df2.select())
checkAnswer(original, df2)

// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates