Skip to content

Commit

Permalink
[SPARK-4202][SQL] Simple DSL support for Scala UDF
Browse files Browse the repository at this point in the history
This feature is based on an offline discussion with mengxr, hopefully can be useful for the new MLlib pipeline API.

For the following test snippet

```scala
case class KeyValue(key: Int, value: String)
val testData = sc.parallelize(1 to 10).map(i => KeyValue(i, i.toString)).toSchemaRDD
def foo(a: Int, b: String) => a.toString + b
```

the newly introduced DSL enables the following syntax

```scala
import org.apache.spark.sql.catalyst.dsl._
testData.select(Star(None), foo.call('key, 'value) as 'result)
```

which is equivalent to

```scala
testData.registerTempTable("testData")
sqlContext.registerFunction("foo", foo)
sql("SELECT *, foo(key, value) AS result FROM testData")
```

Author: Cheng Lian <[email protected]>

Closes #3067 from liancheng/udf-dsl and squashes the following commits:

f132818 [Cheng Lian] Adds DSL support for Scala UDF
  • Loading branch information
liancheng authored and marmbrus committed Nov 3, 2014
1 parent 24544fb commit c238fb4
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.types.decimal.Decimal

import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -285,4 +286,62 @@ package object dsl {
def writeToFile(path: String) = WriteToFile(path, logicalPlan)
}
}

case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) {
def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
}

// scalastyle:off
/** functionToUdfBuilder 1-22 were generated by this script
(1 to 22).map { x =>
val argTypes = Seq.fill(x)("_").mkString(", ")
s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)"
}
*/

implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)

implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
// scalastyle:on
}
17 changes: 13 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.test._

/* Implicits */
import TestSQLContext._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.test.TestSQLContext._

class DslQuerySuite extends QueryTest {
import TestData._
import org.apache.spark.sql.TestData._

test("table scan") {
checkAnswer(
Expand Down Expand Up @@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest {
(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}

test("udf") {
val foo = (a: Int, b: String) => a.toString + b

checkAnswer(
// SELECT *, foo(key, value) FROM testData
testData.select(Star(None), foo.call('key, 'value)).limit(3),
(1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
)
}
}

0 comments on commit c238fb4

Please sign in to comment.