From 5ad78f62056f2560cd371ee964111a646806d0ff Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Thu, 29 Jan 2015 00:01:10 -0800
Subject: [PATCH] [SQL] Various DataFrame DSL update.

1. Added foreach, foreachPartition, flatMap to DataFrame.
2. Added col() in dsl.
3. Support renaming columns in toDataFrame.
4. Support type inference on arrays (in addition to Seq).
5. Updated mllib to use the new DSL.

Author: Reynold Xin

Closes #4260 from rxin/sql-dsl-update and squashes the following commits:

73466c1 [Reynold Xin] Fixed LogisticRegression. Also added better error message for resolve.
fab3ccc [Reynold Xin] Bug fix.
d31fcd2 [Reynold Xin] Style fix.
62608c4 [Reynold Xin] [SQL] Various DataFrame DSL update.
---
 .../org/apache/spark/ml/Transformer.scala     |  3 +-
 .../classification/LogisticRegression.scala   | 12 ++---
 .../spark/ml/feature/StandardScaler.scala     |  3 +-
 .../apache/spark/ml/recommendation/ALS.scala  | 35 +++++---------
 .../apache/spark/mllib/linalg/Vectors.scala   |  3 +-
 .../spark/sql/catalyst/ScalaReflection.scala  |  5 +-
 .../sql/catalyst/ScalaReflectionSuite.scala   |  5 ++
 .../scala/org/apache/spark/sql/Column.scala   | 12 +++--
 .../org/apache/spark/sql/DataFrame.scala      | 47 +++++++++++++++++--
 .../main/scala/org/apache/spark/sql/api.scala |  6 +++
 .../org/apache/spark/sql/api/java/dsl.java    |  7 +++
 .../spark/sql/api/scala/dsl/package.scala     | 21 +++++++++
 12 files changed, 114 insertions(+), 45 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 29cd9810784bc..6eb7ea639c220 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.types._
 
@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
     dataset.select($"*", callUDF(
-      this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
+      this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 101f6c8114559..d82360dcce148 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel
 
@@ -133,15 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val score: Vector => Double = (v) => {
+    val scoreFunction: Vector => Double = (v) => {
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
     }
     val t = map(threshold)
-    val predict: Double => Double = (score) => {
-      if (score > t) 1.0 else 0.0
-    }
-    dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
+    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
+    dataset
+      .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
+      .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index c456beb65d884..78a48561ddf87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
     val scale: (Vector) => Vector = (v) => {
       scaler.transform(v)
     }
-    dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
+    dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 738b1844b5100..474d4731ec0de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    import dataset.sqlContext._
-    import org.apache.spark.ml.recommendation.ALSModel.Factor
+    import dataset.sqlContext.createDataFrame
     val map = this.paramMap ++ paramMap
-    // TODO: Add DSL to simplify the code here.
-    val instanceTable = s"instance_$uid"
-    val userTable = s"user_$uid"
-    val itemTable = s"item_$uid"
-    val instances = dataset.as(instanceTable)
-    val users = userFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(userTable)
-    val items = itemFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(itemTable)
+    val users = userFactors.toDataFrame("id", "features")
+    val items = itemFactors.toDataFrame("id", "features")
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -133,13 +123,14 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
-      .as(map(predictionCol))
-    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
-    instances
-      .join(users, Column(map(userCol)) === $"$userTable.id", "left")
-      .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+    val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+    dataset
+      .join(users, dataset(map(userCol)) === users("id"), "left")
+      .join(items, dataset(map(itemCol)) === items("id"), "left")
       .select(outputColumns: _*)
+    // TODO: Just use a dataset("*")
+    // .select(dataset("*"), prediction)
   }
 
   override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
@@ -147,10 +138,6 @@ class ALSModel private[ml] (
   }
 }
 
-private object ALSModel {
-  /** Case class to convert factors to [[DataFrame]]s */
-  private case class Factor(id: Int, features: Seq[Float])
-}
 
 /**
  * Alternating Least Squares (ALS) matrix factorization.
@@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
     val ratings = dataset
-      .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+      .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
       .map { row =>
         new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 31c33f1bf6fd0..567a8a6c03d90 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 191d16fb10b5f..4def65b01f583 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -57,6 +57,7 @@ trait ScalaReflection {
     case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
     case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
     case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+    case (s: Array[_], arrayType: ArrayType) => s.toSeq
     case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
       convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
     }
@@ -140,7 +141,9 @@ trait ScalaReflection {
     // Need to decide if we actually need a special type here.
     case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
     case t if t <:< typeOf[Array[_]] =>
-      sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+      val TypeRef(_, _, Seq(elementType)) = t
+      val Schema(dataType, nullable) = schemaFor(elementType)
+      Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
     case t if t <:< typeOf[Seq[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
       val Schema(dataType, nullable) = schemaFor(elementType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 5138942a55daa..4a66716e0a782 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -60,6 +60,7 @@ case class OptionalData(
 
 case class ComplexData(
     arrayField: Seq[Int],
+    arrayField1: Array[Int],
     arrayFieldContainsNull: Seq[java.lang.Integer],
     mapField: Map[Int, Long],
     mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -131,6 +132,10 @@ class ScalaReflectionSuite extends FunSuite {
         "arrayField",
         ArrayType(IntegerType, containsNull = false),
         nullable = true),
+      StructField(
+        "arrayField1",
+        ArrayType(IntegerType, containsNull = false),
+        nullable = true),
       StructField(
         "arrayFieldContainsNull",
         ArrayType(IntegerType, containsNull = true),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 7f9a91a032c28..9be2a03afafd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,15 +22,19 @@ import scala.language.implicitConversions
 import org.apache.spark.sql.api.scala.dsl.lit
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
 import org.apache.spark.sql.types._
 
 
 object Column {
-  def unapply(col: Column): Option[Expression] = Some(col.expr)
-
+  /**
+   * Creates a [[Column]] based on the given column name.
+   * Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]].
+   */
   def apply(colName: String): Column = new Column(colName)
+
+  /** For internal pattern matching. */
+  private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
 }
 
 
@@ -438,7 +442,7 @@ class Column(
    * @param ordinal
    * @return
    */
-  override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+  override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
 
   /**
    * An expression that gets a field by name in a [[StructField]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ceb5f86befe71..050366aea8c89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -118,8 +118,8 @@ class DataFrame protected[sql](
 
   /** Resolves a column name into a Catalyst [[NamedExpression]]. */
   protected[sql] def resolve(colName: String): NamedExpression = {
-    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
-      throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw new RuntimeException(
+      s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
   }
 
   /** Left here for compatibility reasons. */
@@ -131,6 +131,29 @@ class DataFrame protected[sql](
    */
   def toDataFrame: DataFrame = this
 
+  /**
+   * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient when
+   * converting an RDD of tuples into a [[DataFrame]] with meaningful names. For example:
+   * {{{
+   *   val rdd: RDD[(Int, String)] = ...
+   *   rdd.toDataFrame  // this implicit conversion creates a DataFrame with column names _1 and _2
+   *   rdd.toDataFrame("id", "name")  // this creates a DataFrame with column names "id" and "name"
+   * }}}
+   */
+  @scala.annotation.varargs
+  def toDataFrame(colName: String, colNames: String*): DataFrame = {
+    val newNames = colName +: colNames
+    require(schema.size == newNames.size,
+      "The number of columns doesn't match.\n" +
+        "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+        "New column names: " + newNames.mkString(", "))
+
+    val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
+      apply(oldName).as(newName)
+    }
+    select(newCols :_*)
+  }
+
   /** Returns the schema of this [[DataFrame]]. */
   override def schema: StructType = queryExecution.analyzed.schema
 
@@ -227,7 +250,7 @@ class DataFrame protected[sql](
   }
 
   /**
-   * Selects a single column and return it as a [[Column]].
+   * Selects a column based on the column name and returns it as a [[Column]].
    */
   override def apply(colName: String): Column = colName match {
     case "*" =>
@@ -466,6 +489,12 @@ class DataFrame protected[sql](
     rdd.map(f)
   }
 
+  /**
+   * Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
+   * and then flattening the results.
+   */
+  override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
   /**
    * Returns a new RDD by applying a function to each partition of this DataFrame.
    */
@@ -473,6 +502,16 @@ class DataFrame protected[sql](
     rdd.mapPartitions(f)
   }
 
+  /**
+   * Applies a function `f` to all rows.
+   */
+  override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+
+  /**
+   * Applies a function `f` to each partition of this [[DataFrame]].
+   */
+  override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+
   /**
    * Returns the first `n` rows in the [[DataFrame]].
   */
@@ -520,7 +559,7 @@ class DataFrame protected[sql](
   /////////////////////////////////////////////////////////////////////////////
 
   /**
-   * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
+   * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
    */
   override def rdd: RDD[Row] = {
     val schema = this.schema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 5eeaf17d71796..59634082f61c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {
 
   def map[R: ClassTag](f: T => R): RDD[R]
 
+  def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]
+
   def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
 
+  def foreach(f: T => Unit): Unit
+
+  def foreachPartition(f: Iterator[T] => Unit): Unit
+
   def take(n: Int): Array[T]
 
   def collect(): Array[T]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
index 74d7649e08cf2..16702afdb31cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
@@ -32,6 +32,13 @@ public class dsl {
 
   private static package$ scalaDsl = package$.MODULE$;
 
+  /**
+   * Returns a {@link Column} based on the given column name.
+   */
+  public static Column col(String colName) {
+    return new Column(colName);
+  }
+
   /**
    * Creates a column of literal value.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
index 9f2d1427d4a62..dc851fc5048ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
@@ -37,6 +38,21 @@ package object dsl {
   /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
   implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
 
+//  /**
+//   * An implicit conversion that turns an RDD of product into a [[DataFrame]].
+//   *
+//   * This method requires an implicit SQLContext in scope. For example:
+//   * {{{
+//   *   implicit val sqlContext: SQLContext = ...
+//   *   val rdd: RDD[(Int, String)] = ...
+//   *   rdd.toDataFrame  // triggers the implicit here
+//   * }}}
+//   */
+//  implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
+//    : DataFrame = {
+//    context.createDataFrame(rdd)
+//  }
+
   /** Converts $"col name" into an [[Column]]. */
   implicit class StringToColumn(val sc: StringContext) extends AnyVal {
     def $(args: Any*): ColumnName = {
@@ -46,6 +62,11 @@ package object dsl {
 
   private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
 
+  /**
+   * Returns a [[Column]] based on the given column name.
+   */
+  def col(colName: String): Column = new Column(colName)
+
   /**
    * Creates a [[Column]] of literal value.
   */
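
Usage sketches for the APIs introduced above (not part of the diff). First, the new toDataFrame renaming overload and the col() DSL helper. This assumes a SQLContext named sqlContext whose implicit createDataFrame conversion (the one imported in ALS.scala above) is in scope, plus an existing RDD[(Int, String)] named rdd; both names are hypothetical:

  import org.apache.spark.sql.api.scala.dsl._
  import sqlContext.createDataFrame       // implicit RDD -> DataFrame conversion, as used in ALS.scala

  val df = rdd.toDataFrame                   // column names default to _1 and _2
  val named = rdd.toDataFrame("id", "name")  // new varargs overload renames the columns
  named.select(col("name"), $"id")           // col() is the new DSL function; $"..." already existed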
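
The new flatMap, foreach and foreachPartition methods on DataFrame simply delegate to the underlying RDD[Row], as the DataFrame.scala hunk shows. A minimal sketch, assuming a DataFrame named df whose first column is a String (hypothetical data):

  import org.apache.spark.rdd.RDD

  // flatMap: one input row can produce zero or more output records.
  val words: RDD[String] = df.flatMap(row => row.getString(0).split(" "))

  // foreach: runs on the executors, once per row; output goes to executor logs.
  df.foreach(row => println(row))

  // foreachPartition: set up per-partition state once, then drain the iterator.
  df.foreachPartition { rows =>
    var count = 0
    rows.foreach(_ => count += 1)
    println(s"partition had $count rows")
  }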
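
The mllib changes above consistently replace Column(name) with the new col(name) when passing columns into callUDF. A sketch of the same pattern outside mllib, mirroring LogisticRegressionModel.transform; the DataFrame df and its Double column "x" are made up for illustration:

  import org.apache.spark.sql.api.scala.dsl._

  val plusOne: Double => Double = _ + 1.0
  // Keep every existing column ($"*") and append one derived column computed by a Scala function.
  val withExtra = df.select($"*", callUDF(plusOne, col("x")).as("x_plus_one"))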
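
With the ScalaReflection change above, Array fields in a case class are now inferred as ArrayType, the same way Seq fields are, instead of being rejected for anything other than Array[Byte]. A sketch; the Record class is hypothetical:

  import org.apache.spark.sql.catalyst.ScalaReflection
  import org.apache.spark.sql.types._

  case class Record(id: Int, scores: Array[Double])

  val structType = ScalaReflection.schemaFor[Record].dataType.asInstanceOf[StructType]
  val scoresField = structType.fields.find(_.name == "scores").get
  // The element type is non-nullable Double, so containsNull = false (the field itself stays nullable).
  println(scoresField.dataType == ArrayType(DoubleType, containsNull = false))  // true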