Skip to content

Commit

Permalink
simple version
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Feb 12, 2015
1 parent d633d01 commit dc86a5c
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 2 deletions.
17 changes: 16 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ trait DataFrame extends RDDApi[Row] {

/**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
* the input row are implicitly joined with each row that is output by the function.
*
* The following example uses this function to count the number of books which contain
Expand All @@ -481,6 +481,21 @@ trait DataFrame extends RDDApi[Row] {
*/
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame


/**
* (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
* or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
* columns of the input row are implicitly joined with each value that is output by the function.
*
* The function literal must be written with braces (or with its parameter list parenthesized):
* `(words: String => ...)` inside plain parentheses is parsed as a type ascription and does not
* compile.
*
* {{{
* df.explode("words", "word") { words: String => words.split(" ").toSeq }
* }}}
*
* @tparam A type of the values held by `inputColumn`, as seen by `f`. There is no `TypeTag`
*           bound on `A`, so this signature cannot check it against the column's actual type.
* @tparam B type of the values produced by `f`; the `TypeTag` context bound makes its runtime
*           type available to implementations.
* @param inputColumn name of the column whose value is passed to `f` for each input row.
* @param outputColumn name of the new column holding the values produced by `f`.
* @param f function that expands one input value into zero or more output values.
* @return a [[DataFrame]] with one row per value returned by `f`.
*/
def explode[A, B : TypeTag](
inputColumn: String,
outputColumn: String)(
f: A => TraversableOnce[B]): DataFrame

/////////////////////////////////////////////////////////////////////////////

/**
Expand Down
15 changes: 15 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,21 @@ private[sql] class DataFrameImpl protected[sql](
Generate(generator, join = true, outer = false, None, logicalPlan)
}

override def explode[A, B : TypeTag](
    inputColumn: String,
    outputColumn: String)(
    f: A => TraversableOnce[B]): DataFrame = {
  // Catalyst type of the generated column, derived from B's TypeTag.
  val elementType = ScalaReflection.schemaFor[B].dataType
  val outputAttributes = AttributeReference(outputColumn, elementType)() :: Nil
  // The input column is the generator's only child expression, which is why the
  // row function below reads its argument from ordinal 0.
  val generatorInput = apply(inputColumn).expr :: Nil

  // Adapts the user's A => TraversableOnce[B] into the Row => TraversableOnce[Row]
  // shape handed to UserDefinedGenerator, wrapping each converted value in a Row.
  def expandRow(row: Row) = {
    val produced = f(row(0).asInstanceOf[A])
    produced.map { value =>
      Row(ScalaReflection.convertToCatalyst(value, elementType))
    }
  }

  val generator = UserDefinedGenerator(outputAttributes, expandRow, generatorInput)
  Generate(generator, join = true, outer = false, None, logicalPlan)
}


/////////////////////////////////////////////////////////////////////////////
// RDD API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,12 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
// Row-based explode: not implemented for this class; the call is forwarded to err().
// NOTE(review): err() is defined outside this hunk and presumably raises an error — confirm.
override def explode[A <: Product : TypeTag]
(input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err()

/////////////////////////////////////////////////////////////////////////////
// Single-column explode: not implemented for this class; the call is forwarded to err().
// NOTE(review): err() is defined outside this hunk and presumably raises an error — confirm.
override def explode[A, B : TypeTag](
inputColumn: String,
outputColumn: String)(
f: A => TraversableOnce[B]): DataFrame = err()

/////////////////////////////////////////////////////////////////////////////

// head(n): not implemented for this class; the call is forwarded to err().
override def head(n: Int): Array[Row] = err()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ class DataFrameSuite extends QueryTest {
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
}

test("simple explode") {
  // Single-column frame where each row holds a space-separated list of words.
  val sentences = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")

  // Each input string is split into its words, one output row per word.
  val exploded = sentences.explode("words", "word") { sentence: String =>
    sentence.split(" ").toSeq
  }

  checkAnswer(
    exploded.select('word),
    Seq("a", "b", "c", "d", "e").map(Row(_)))
}

test("explode") {
val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
val df2 =
Expand Down

0 comments on commit dc86a5c

Please sign in to comment.