[SQL] Various DataFrame DSL update. #4260

Closed · wants to merge 4 commits (changes from all commits)
3 changes: 1 addition & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.types._

@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
dataset.select($"*", callUDF(
this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
}
}
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.storage.StorageLevel

@@ -133,15 +132,14 @@ class LogisticRegressionModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val score: Vector => Double = (v) => {
val scoreFunction: Vector => Double = (v) => {
val margin = BLAS.dot(v, weights)
1.0 / (1.0 + math.exp(-margin))
}
val t = map(threshold)
val predict: Double => Double = (score) => {
if (score > t) 1.0 else 0.0
}
dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
.select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
dataset
.select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
Contributor commented:

minor: The word col might be used as matrix column index in ML algorithms.

This line is still not straightforward to read. I'm thinking of something like the following

val scoreFunc = UDF((score: Double) => {if (score > t) 1.0 else 0.0})
dataset.select($"*", scoreFunc(col(map(featuresCol))).as(map(scoreCol))

.select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
}
}
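
The callUDF pattern used in the transform above, shown in isolation: wrap a plain Scala function, apply it to a column selected by name with col(...), and alias the result. A minimal sketch, assuming a DataFrame df with a Double column named "score"; the threshold value is illustrative.

import org.apache.spark.sql.api.scala.dsl._

val t = 0.5
// Plain Scala function turned into a column expression via callUDF.
val predict: Double => Double = score => if (score > t) 1.0 else 0.0
// Keep all existing columns ($"*") and append the aliased prediction column.
val withPrediction = df.select($"*", callUDF(predict, col("score")).as("prediction"))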
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{StructField, StructType}

/**
@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
val scale: (Vector) => Vector = (v) => {
scaler.transform(v)
}
dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
}

private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
35 changes: 11 additions & 24 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext._
import org.apache.spark.ml.recommendation.ALSModel.Factor
import dataset.sqlContext.createDataFrame
val map = this.paramMap ++ paramMap
// TODO: Add DSL to simplify the code here.
val instanceTable = s"instance_$uid"
val userTable = s"user_$uid"
val itemTable = s"item_$uid"
val instances = dataset.as(instanceTable)
val users = userFactors.map { case (id, features) =>
Factor(id, features)
}.as(userTable)
val items = itemFactors.map { case (id, features) =>
Factor(id, features)
}.as(itemTable)
val users = userFactors.toDataFrame("id", "features")
val items = itemFactors.toDataFrame("id", "features")
val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
if (userFeatures != null && itemFeatures != null) {
blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -133,24 +123,21 @@
}
}
val inputColumns = dataset.schema.fieldNames
val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
.as(map(predictionCol))
val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
instances
.join(users, Column(map(userCol)) === $"$userTable.id", "left")
.join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
dataset
.join(users, dataset(map(userCol)) === users("id"), "left")
.join(items, dataset(map(itemCol)) === items("id"), "left")
.select(outputColumns: _*)
// TODO: Just use a dataset("*")
// .select(dataset("*"), prediction)
Contributor commented:

Similarly, this could be .select(dataset("*"), predictionFunc(users("features"), items("features")).as(map(predictionCol))).

}

override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}

private object ALSModel {
/** Case class to convert factors to [[DataFrame]]s */
private case class Factor(id: Int, features: Seq[Float])
}

/**
* Alternating Least Squares (ALS) matrix factorization.
Expand Down Expand Up @@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
val ratings = dataset
.select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
.select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
.map { row =>
new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}

import org.apache.spark.SparkException
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types._

/**
@@ -57,6 +57,7 @@ trait ScalaReflection {
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
case (s: Array[_], arrayType: ArrayType) => s.toSeq
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
@@ -140,7 +141,9 @@ trait ScalaReflection {
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
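
A sketch of what the new Array[_] case enables, using a hypothetical case class that mirrors the arrayField1 addition in the test suite below: schemaFor now maps an Array element type to an ArrayType instead of failing with "Only Array[Byte] supported now". This assumes the companion object ScalaReflection that mixes in this trait.

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

// Hypothetical example class with an Array field.
case class Record(id: Int, scores: Array[Int])

val schema = ScalaReflection.schemaFor[Record]
// Per the ScalaReflectionSuite expectation below, schema.dataType is now:
//   StructType(Seq(
//     StructField("id", IntegerType, nullable = false),
//     StructField("scores", ArrayType(IntegerType, containsNull = false), nullable = true)))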
@@ -60,6 +60,7 @@ case class OptionalData(

case class ComplexData(
arrayField: Seq[Int],
arrayField1: Array[Int],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -131,6 +132,10 @@ class ScalaReflectionSuite extends FunSuite {
"arrayField",
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
"arrayField1",
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
"arrayFieldContainsNull",
ArrayType(IntegerType, containsNull = true),
12 changes: 8 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,15 +22,19 @@ import scala.language.implicitConversions
import org.apache.spark.sql.api.scala.dsl.lit
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.types._


object Column {
def unapply(col: Column): Option[Expression] = Some(col.expr)

/**
* Creates a [[Column]] based on the given column name.
* Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]].
*/
def apply(colName: String): Column = new Column(colName)

/** For internal pattern matching. */
private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
}


@@ -438,7 +442,7 @@ class Column(
* @param ordinal
* @return
*/
override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))

/**
* An expression that gets a field by name in a [[StructField]].
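
A quick sketch of the new Column.apply(colName) factory added above; per its scaladoc it is interchangeable with the col(...) helpers in the Scala and Java DSLs. df is an assumed DataFrame with an "age" column.

import org.apache.spark.sql.Column
import org.apache.spark.sql.api.scala.dsl.col

val byFactory: Column = Column("age")   // new companion-object factory
val byDsl: Column = col("age")          // equivalent DSL helper
df.select(byFactory, byDsl.as("age_copy"))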
47 changes: 43 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -118,8 +118,8 @@ class DataFrame protected[sql](

/** Resolves a column name into a Catalyst [[NamedExpression]]. */
protected[sql] def resolve(colName: String): NamedExpression = {
logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw new RuntimeException(
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
}

/** Left here for compatibility reasons. */
@@ -131,6 +131,29 @@
*/
def toDataFrame: DataFrame = this

/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient when converting
* an RDD of tuples into a [[DataFrame]] with meaningful names. For example:
* {{{
* val rdd: RDD[(Int, String)] = ...
* rdd.toDataFrame // this implicit conversion creates a DataFrame with column names _1 and _2
* rdd.toDataFrame("id", "name") // this creates a DataFrame with column names "id" and "name"
* }}}
*/
@scala.annotation.varargs
def toDataFrame(colName: String, colNames: String*): DataFrame = {
val newNames = colName +: colNames
require(schema.size == newNames.size,
"The number of columns doesn't match.\n" +
"Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
"New column names: " + newNames.mkString(", "))

val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
apply(oldName).as(newName)
}
select(newCols :_*)
}

/** Returns the schema of this [[DataFrame]]. */
override def schema: StructType = queryExecution.analyzed.schema

@@ -227,7 +250,7 @@ class DataFrame protected[sql](
}

/**
* Selects a single column and return it as a [[Column]].
* Selects a column based on the column name and returns it as a [[Column]].
*/
override def apply(colName: String): Column = colName match {
case "*" =>
@@ -466,13 +489,29 @@
rdd.map(f)
}

/**
* Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
* and then flattening the results.
*/
override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)

/**
* Returns a new RDD by applying a function to each partition of this DataFrame.
*/
override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
rdd.mapPartitions(f)
}

/**
* Applies a function `f` to all rows.
*/
override def foreach(f: Row => Unit): Unit = rdd.foreach(f)

/**
* Applies a function f to each partition of this [[DataFrame]].
*/
override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)

/**
* Returns the first `n` rows in the [[DataFrame]].
*/
@@ -520,7 +559,7 @@ class DataFrame protected[sql](
/////////////////////////////////////////////////////////////////////////////

/**
* Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
* Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
*/
override def rdd: RDD[Row] = {
val schema = this.schema
6 changes: 6 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {

def map[R: ClassTag](f: T => R): RDD[R]

def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]

def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]

def foreach(f: T => Unit): Unit

def foreachPartition(f: Iterator[T] => Unit): Unit

def take(n: Int): Array[T]

def collect(): Array[T]
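
The DataFrame overrides for these four new methods are in the DataFrame.scala hunk above; a short sketch of how they read in user code, with df an assumed DataFrame:

df.flatMap(row => row.toSeq)                        // RDD[Any]: flatten every row into its values
df.mapPartitions(rows => rows.map(_.length))        // RDD[Int]: transform one partition at a time
df.foreach(row => println(row))                     // run a side effect on every row
df.foreachPartition(rows => rows.foreach(println))  // the closure runs once per partition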
@@ -32,6 +32,13 @@ public class dsl {

private static package$ scalaDsl = package$.MODULE$;

/**
* Returns a {@link Column} based on the given column name.
*/
public static Column col(String colName) {
return new Column(colName);
}

/**
* Creates a column of literal value.
*/
@@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
@@ -37,6 +38,21 @@ package object dsl {
/** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

// /**
// * An implicit conversion that turns a RDD of product into a [[DataFrame]].
// *
// * This method requires an implicit SQLContext in scope. For example:
// * {{{
// * implicit val sqlContext: SQLContext = ...
// * val rdd: RDD[(Int, String)] = ...
// * rdd.toDataFrame // triggers the implicit here
// * }}}
// */
// implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
// : DataFrame = {
// context.createDataFrame(rdd)
// }

/** Converts $"col name" into an [[Column]]. */
implicit class StringToColumn(val sc: StringContext) extends AnyVal {
def $(args: Any*): ColumnName = {
@@ -46,6 +62,11 @@

private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)

/**
* Returns a [[Column]] based on the given column name.
*/
def col(colName: String): Column = new Column(colName)

/**
* Creates a [[Column]] of literal value.
*/
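
Putting the Scala DSL pieces touched in this file together: col, the $"..." interpolator, the Symbol conversion, and lit. A hedged sketch, assuming a DataFrame df with "name" and "age" columns.

import org.apache.spark.sql.api.scala.dsl._

// Three equivalent ways to reference a column, plus a literal expression.
df.select(col("name"), $"age" + lit(1), 'age.as("age_again"))
df.filter($"age" > lit(21))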