[SPARK-5573][SQL] Add explode to dataframes #4546

Closed · wants to merge 7 commits
Changes from 3 commits
@@ -73,6 +73,25 @@ abstract class Generator extends Expression {
}
}

/**
 * A generator that produces its output using the provided lambda function.
 */
case class UserDefinedGenerator(
    schema: Seq[Attribute],
    function: Row => TraversableOnce[Row],
    children: Seq[Expression])
  extends Generator {

  override protected def makeOutput(): Seq[Attribute] = schema

  override def eval(input: Row): TraversableOnce[Row] = {
    // Project the child expressions out of the input row, then hand the
    // projected row to the user's lambda.
    val inputRow = new InterpretedProjection(children)
    function(inputRow(input))
  }

  override def toString = s"UserDefinedGenerator(${children.mkString(",")})"
}
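
For intuition, a minimal REPL-style sketch of the eval contract above, using plain Seq[Any] values as stand-ins for Catalyst rows (illustrative only, not Spark API):

// The projection built from `children` keeps just the columns the lambda
// needs; here, column 1 (the words column).
val project: Seq[Any] => Seq[Any] = row => Seq(row(1))

// The user's lambda: one input row in, zero or more output rows out.
val function: Seq[Any] => TraversableOnce[Seq[Any]] =
  row => row.head.toString.split(" ").map(word => Seq(word))

// eval is the projection composed with the lambda.
val eval: Seq[Any] => TraversableOnce[Seq[Any]] =
  input => function(project(input))

eval(Seq(1, "a b c")).foreach(println)  // List(a), List(b), List(c)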

/**
* Given an input array produces a sequence of rows for each value in the array.
*/
23 changes: 23 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -458,6 +459,28 @@ trait DataFrame extends RDDApi[Row] {
sample(withReplacement, fraction, Utils.random.nextLong)
}

  /**
   * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
   * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
   * the input row are implicitly joined with each row that is output by the function.
   *
   * The following example uses this function to count the number of books that contain
   * a given word:
   *
   * {{{
   *   case class Book(title: String, words: String)
   *   val df: DataFrame  // a DataFrame of Books
   *
   *   case class Word(word: String)
   *   val allWords = df.explode('words) {
   *     case Row(words: String) => words.split(" ").map(Word(_))
   *   }
   *
   *   val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
   * }}}
   */
  def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame

/////////////////////////////////////////////////////////////////////////////

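A hedged end-to-end sketch of the new API, mirroring the style of the test added below (assumes the implicit conversions used by DataFrameSuite, such as toDataFrame and the Symbol-to-Column conversion, are in scope; data and names are illustrative):

import org.apache.spark.sql.Row

val books = Seq(
  ("Spark Guide", "fast general"),
  ("Tiny Book", "fast")
).toDataFrame("title", "words")

// Each input row is implicitly joined with every row the lambda emits, so
// title and words are carried alongside the generated column (_1).
val allWords = books.explode('words) {
  case Row(words: String) => words.split(" ").map(Tuple1(_))
}

allWords.collect().foreach(println)
// [Spark Guide,fast general,fast]
// [Spark Guide,fast general,general]
// [Tiny Book,fast,fast]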
16 changes: 14 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -21,6 +21,7 @@ import java.io.CharArrayWriter

import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.collection.JavaConversions._

import com.fasterxml.jackson.core.JsonFactory
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{NumericType, StructType}


/**
* Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
*/
@@ -292,6 +292,18 @@ private[sql] class DataFrameImpl protected[sql](
Sample(fraction, withReplacement, seed, logicalPlan)
}

  override def explode[A <: Product : TypeTag]
      (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
    // Derive the Catalyst schema of the generated rows from the type parameter A.
    val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
    val attributes = schema.toAttributes
    // Wrap the user's function so the Products it returns become Catalyst rows.
    val rowFunction =
      f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]))
    val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))

    // join = true keeps the input columns alongside the generated ones (LATERAL VIEW).
    Generate(generator, join = true, outer = false, None, logicalPlan)
  }

/////////////////////////////////////////////////////////////////////////////
// RDD API
/////////////////////////////////////////////////////////////////////////////
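The key step above is deriving a Catalyst schema from the type parameter A. A rough illustration of what that reflection yields for a simple case class (a sketch; exact field metadata may differ):

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

case class Word(word: String)

// For a case class, schemaFor returns a Schema whose dataType is a StructType
// with one field per constructor parameter: here, a single nullable StringType
// field named "word". toAttributes then turns those fields into the
// generator's output attributes.
val schema = ScalaReflection.schemaFor[Word].dataType.asInstanceOf[StructType]
val attributes = schema.toAttributes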
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
@@ -114,7 +115,10 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten

override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()

/////////////////////////////////////////////////////////////////////////////
  override def explode[A <: Product : TypeTag]
      (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err()

/////////////////////////////////////////////////////////////////////////////
Contributor: while you are at it, this is indented wrong now


override def head(n: Int): Array[Row] = err()

16 changes: 16 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -98,6 +98,22 @@ class DataFrameSuite extends QueryTest {
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
}

test("explode") {
val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
val df2 =
df.explode('letters) {
case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
}

checkAnswer(
df2
.select('_1 as 'letter, 'number)
.groupBy('letter)
.agg('letter, countDistinct('number)),
Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
)
}
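
For reference, a worked trace of the rows behind the expected counts (not part of the test):

// df.explode('letters){...} joins each input row with its split letters:
//   (1, "a b c") -> (1, "a b c", "a"), (1, "a b c", "b"), (1, "a b c", "c")
//   (2, "a b")   -> (2, "a b", "a"),   (2, "a b", "b")
//   (3, "a")     -> (3, "a", "a")
// Grouping by letter and counting distinct numbers yields a -> 3, b -> 2, c -> 1.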

test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),