Skip to content

Commit

Permalink
[SPARK-5573][SQL] Add explode to dataframes
Browse files Browse the repository at this point in the history
Author: Michael Armbrust <[email protected]>

Closes #4546 from marmbrus/explode and squashes the following commits:

eefd33a [Michael Armbrust] whitespace
a8d496c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into explode
4af740e [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explode
dc86a5c [Michael Armbrust] simple version
d633d01 [Michael Armbrust] add scala specific
950707a [Michael Armbrust] fix comments
ba8854c [Michael Armbrust] [SPARK-5573][SQL] Add explode to dataframes

(cherry picked from commit ee04a8b)
Signed-off-by: Michael Armbrust <[email protected]>
  • Loading branch information
marmbrus committed Feb 12, 2015
1 parent b0c79da commit c7eb9ee
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,25 @@ abstract class Generator extends Expression {
}
}

/**
 * A generator that produces its output using the provided lambda function.
 *
 * For every input row the `children` expressions are evaluated into a single row, and that row
 * is handed to `function`, which returns zero or more output rows.
 *
 * @param schema   attributes describing the rows emitted by `function`; returned unchanged as
 *                 this generator's output
 * @param function user-supplied lambda invoked with the projected `children` of each input row
 * @param children expressions evaluated against the input row to build the argument passed to
 *                 `function`
 */
case class UserDefinedGenerator(
    schema: Seq[Attribute],
    function: Row => TraversableOnce[Row],
    children: Seq[Expression])
  extends Generator {

  // The projection depends only on `children`, so it is built once per generator instance
  // instead of once per evaluated row. It is transient and lazy so that it is not shipped with
  // the serialized expression and is rebuilt on first use after deserialization.
  @transient private lazy val inputProjection = new InterpretedProjection(children)

  override protected def makeOutput(): Seq[Attribute] = schema

  /** Projects `children` out of `input` and applies the user function to the projected row. */
  override def eval(input: Row): TraversableOnce[Row] = function(inputProjection(input))

  override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
}

/**
* Given an input array produces a sequence of rows for each value in the array.
*/
Expand Down
38 changes: 38 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql

import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.spark.annotation.{DeveloperApi, Experimental}
Expand Down Expand Up @@ -441,6 +442,43 @@ trait DataFrame extends RDDApi[Row] with Serializable {
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
 * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
 * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
 * the input row are implicitly joined with each row that is output by the function.
 *
 * The following example uses this function to count the number of books which contain
 * a given word:
 *
 * {{{
 *   case class Book(title: String, words: String)
 *   val df: DataFrame  // a DataFrame whose columns are those of Book
 *
 *   case class Word(word: String)
 *   val allWords = df.explode('words) {
 *     case Row(words: String) => words.split(" ").map(Word(_)).toSeq
 *   }
 *
 *   val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
 * }}}
 *
 * @tparam A product type of the values produced by `f`; its fields become the new columns
 * @param input columns whose values are handed to `f`, packed into a single [[Row]]
 * @param f function that produces zero or more output values for each input row
 */
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame


/**
 * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
 * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
 * columns of the input row are implicitly joined with each value that is output by the function.
 *
 * {{{
 *   df.explode("words", "word") { words: String => words.split(" ").toSeq }
 * }}}
 *
 * @tparam A type of the value stored in `inputColumn`, as expected by `f`
 * @tparam B type of the values produced by `f`; it determines the type of `outputColumn`
 * @param inputColumn name of the column whose value is passed to `f`
 * @param outputColumn name of the column that holds the values produced by `f`
 * @param f function that produces zero or more output values for each input value
 */
def explode[A, B : TypeTag](
inputColumn: String,
outputColumn: String)(
f: A => TraversableOnce[B]): DataFrame

/////////////////////////////////////////////////////////////////////////////

/**
Expand Down
30 changes: 28 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.io.CharArrayWriter

import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.collection.JavaConversions._

import com.fasterxml.jackson.core.JsonFactory
Expand All @@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
Expand All @@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{NumericType, StructType}


/**
* Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
*/
Expand Down Expand Up @@ -282,6 +282,32 @@ private[sql] class DataFrameImpl protected[sql](
Sample(fraction, withReplacement, seed, logicalPlan)
}

/**
 * Expands each row into zero or more rows: `f` receives the values of the `input` columns as
 * a [[Row]], and every product it emits is converted to a catalyst row whose schema is
 * derived by reflection from `A`. The resulting plan joins the generated rows with the
 * original input columns.
 */
override def explode[A <: Product : TypeTag]
    (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
  val productSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]

  // Wrap the user function so each emitted product becomes a catalyst Row.
  val toCatalystRows: Row => TraversableOnce[Row] = { row =>
    f(row).map { product =>
      ScalaReflection.convertToCatalyst(product, productSchema).asInstanceOf[Row]
    }
  }

  val inputExpressions = input.map(_.expr)
  val generator =
    UserDefinedGenerator(productSchema.toAttributes, toCatalystRows, inputExpressions)

  Generate(generator, join = true, outer = false, None, logicalPlan)
}

/**
 * Expands the single column `inputColumn` into zero or more rows: `f` receives the column
 * value cast to `A`, and every value it emits is converted to its catalyst representation
 * and exposed as `outputColumn`. The resulting plan joins the generated values with the
 * original input columns.
 */
override def explode[A, B : TypeTag](
    inputColumn: String,
    outputColumn: String)(
    f: A => TraversableOnce[B]): DataFrame = {
  val elementType = ScalaReflection.schemaFor[B].dataType
  val outputAttributes = List(AttributeReference(outputColumn, elementType)())

  // The generator projects only `inputColumn`, so its value is always at ordinal 0.
  val toRows: Row => TraversableOnce[Row] = { row =>
    f(row(0).asInstanceOf[A]).map { value =>
      Row(ScalaReflection.convertToCatalyst(value, elementType))
    }
  }

  val inputExpressions = List(apply(inputColumn).expr)
  val generator = UserDefinedGenerator(outputAttributes, toRows, inputExpressions)

  Generate(generator, join = true, outer = false, None, logicalPlan)
}

/////////////////////////////////////////////////////////////////////////////
// RDD API
/////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -110,6 +111,14 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten

// Produces no sample; the call is routed straight to err().
// NOTE(review): err() is defined outside this view -- presumably it raises; confirm.
override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()

// Row-based explode: performs no expansion; the call is routed straight to err().
// NOTE(review): err() is defined outside this view -- presumably it raises; confirm.
override def explode[A <: Product : TypeTag]
(input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err()

// Single-column explode: performs no expansion; the call is routed straight to err().
// NOTE(review): err() is defined outside this view -- presumably it raises; confirm.
override def explode[A, B : TypeTag](
inputColumn: String,
outputColumn: String)(
f: A => TraversableOnce[B]): DataFrame = err()

/////////////////////////////////////////////////////////////////////////////

// Fetches no rows; the call is routed straight to err().
// NOTE(review): err() is defined outside this view -- presumably it raises; confirm.
override def head(n: Int): Array[Row] = err()
Expand Down
25 changes: 25 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ class DataFrameSuite extends QueryTest {
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
}

test("simple explode") {
  // Two sentences stored in a single "words" column.
  val sentences = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")

  // Split each sentence on spaces; every token becomes its own row in the "word" column.
  val exploded =
    sentences.explode("words", "word") { words: String => words.split(" ").toSeq }

  val expected = List(Row("a"), Row("b"), Row("c"), Row("d"), Row("e"))

  checkAnswer(exploded.select('word), expected)
}

test("explode") {
  val numbered = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")

  // One output row per letter; the source row's columns are joined onto each of them.
  val perLetter = numbered.explode('letters) {
    case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
  }

  // Count how many distinct numbers each letter appears with.
  val numbersPerLetter = perLetter
    .select('_1 as 'letter, 'number)
    .groupBy('letter)
    .agg('letter, countDistinct('number))

  val expected = List(Row("a", 3), Row("b", 2), Row("c", 1))

  checkAnswer(numbersPerLetter, expected)
}

test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),
Expand Down

0 comments on commit c7eb9ee

Please sign in to comment.