[SPARK-3572] [SQL] Internal API for User-Defined Types #3063

Closed · wants to merge 48 commits

Commits (48)
105c5a3  Adding UserDefinedType to SQL, not done yet. (jkbradley, Oct 3, 2014)
0eaeb81  Still working on UDTs (jkbradley, Oct 6, 2014)
19b2f60  still working on UDTs (jkbradley, Oct 6, 2014)
982c035  still working on UDTs (jkbradley, Oct 7, 2014)
53de70f  more udts... (jkbradley, Oct 7, 2014)
8bebf24  commented out convertRowToScala for debugging (jkbradley, Oct 7, 2014)
273ac96  basic UDT is working, but deserialization has yet to be done (jkbradley, Oct 8, 2014)
39f8707  removed old udt suite (jkbradley, Oct 8, 2014)
04303c9  udts (jkbradley, Oct 9, 2014)
50f9726  udts (jkbradley, Oct 9, 2014)
893ee4c  udt finallly working (jkbradley, Oct 9, 2014)
964b32e  some cleanups (jkbradley, Oct 9, 2014)
fea04af  more cleanups (jkbradley, Oct 9, 2014)
b226b9e  Changing UDT to annotation (jkbradley, Oct 10, 2014)
3579035  udt annotation now working (jkbradley, Oct 10, 2014)
2f40c02  renamed UDT types (jkbradley, Oct 10, 2014)
e1f7b9c  blah (jkbradley, Oct 10, 2014)
34a5831  Added MLlib dependency on SQL. (jkbradley, Oct 10, 2014)
cd60cb4  Trying to get other SQL tests to run (jkbradley, Oct 21, 2014)
dff99d6  Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs (jkbradley, Oct 22, 2014)
85872f6  Allow schema calculation to be lazy, but ensure its available on exec… (marmbrus, Oct 23, 2014)
f025035  Cleanups before PR. Added new tests (jkbradley, Oct 24, 2014)
51e5282  fixed 1 test (jkbradley, Oct 24, 2014)
63626a4  Updated ScalaReflectionsSuite per @marmbrus suggestions (jkbradley, Oct 24, 2014)
759af7a  Added more doc to UserDefineType (jkbradley, Oct 27, 2014)
db16139  Added more doc for UserDefinedType. Removed unused code in Suite (jkbradley, Oct 28, 2014)
cfbc321  support UDT in parquet (mengxr, Oct 28, 2014)
3143ac3  remove unnecessary changes (mengxr, Oct 28, 2014)
87264a5  remove debug code (mengxr, Oct 28, 2014)
4500d8a  update example code (mengxr, Oct 28, 2014)
b028675  allow any type in UDT (mengxr, Oct 28, 2014)
7f29656  Moved udt case to top of all matches. Small cleanups (jkbradley, Oct 28, 2014)
8b242ea  Fixed merge error after last merge. Note: Last merge commit also rem… (jkbradley, Oct 29, 2014)
8de957c  Modified UserDefinedType to store Java class of user type so that reg… (jkbradley, Oct 30, 2014)
fa86b20  Removed Java UserDefinedType, and made UDTs private[spark] for now (jkbradley, Oct 31, 2014)
20630bc  fixed scalastyle (jkbradley, Oct 31, 2014)
6fddc1c  Made MyLabeledPoint into a Java Bean (jkbradley, Oct 31, 2014)
a571bb6  Removed old UDT code (registry and Java UDTs). Cleaned up other code… (jkbradley, Oct 31, 2014)
d063380  Cleaned up Java UDT Suite, and added warning about element ordering w… (jkbradley, Oct 31, 2014)
30ce5b2  updates based on code review (jkbradley, Nov 2, 2014)
5817b2b  style edits (jkbradley, Nov 2, 2014)
e13cd8a  Removed Vector UDTs (jkbradley, Nov 2, 2014)
f3c72fe  Fixing merge (jkbradley, Nov 2, 2014)
15c10a6  Merge remote-tracking branch 'upstream/master' into sql-udt2 (jkbradley, Nov 2, 2014)
e369b91  Merge remote-tracking branch 'origin/master' into udts (marmbrus, Nov 3, 2014)
6cc434d  Recursively convert rows. (marmbrus, Nov 3, 2014)
46a3aee  Slightly easier to read test output. (marmbrus, Nov 3, 2014)
7ccfc0d  remove println (marmbrus, Nov 3, 2014)
Files changed (from all commits)

ScalaReflection.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}

import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@@ -35,25 +37,46 @@ object ScalaReflection {

case class Schema(dataType: DataType, nullable: Boolean)

/** Converts Scala objects to catalyst rows / types */
def convertToCatalyst(a: Any): Any = a match {
case o: Option[_] => o.map(convertToCatalyst).orNull
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
case d: BigDecimal => Decimal(d)
case other => other
/**
* Converts Scala objects to catalyst rows / types.
* Note: This is always called after schemaFor has been called.
* This ordering is important for UDT registration.
*/
def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
// Check UDT first since UDTs can override other types
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
case (p: Product, structType: StructType) =>
new GenericRow(
p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
convertToCatalyst(elem, field.dataType)
}.toArray)
case (d: BigDecimal, _) => Decimal(d)
case (other, _) => other
}
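// A minimal usage sketch (assuming the definitions above): threading the
// DataType through the recursion lets each element convert against its own
// declared type, e.g.
//   convertToCatalyst(Seq(BigDecimal(1), BigDecimal(2)), ArrayType(DecimalType.Unlimited))
// matches the Seq/ArrayType case and yields Seq(Decimal(1), Decimal(2)).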

/** Converts Catalyst types used internally in rows to standard Scala types */
def convertToScala(a: Any): Any = a match {
case s: Seq[_] => s.map(convertToScala)
case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
case d: Decimal => d.toBigDecimal
case other => other
def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
// Check UDT first since UDTs can override other types
case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (r: Row, s: StructType) => convertRowToScala(r, s)
case (d: Decimal, _: DecimalType) => d.toBigDecimal
case (other, _) => other
}

def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
def convertRowToScala(r: Row, schema: StructType): Row = {
new GenericRow(
r.zip(schema.fields.map(_.dataType))
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
}
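// Sketch of the pairing above: for a row over
//   StructType(Seq(StructField("d", DecimalType.Unlimited, nullable = true)))
// holding Decimal(1), convertRowToScala zips the single column with its
// DecimalType and returns a GenericRow holding BigDecimal(1).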

/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
@@ -65,52 +88,64 @@ object ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
def schemaFor(tpe: `Type`): Schema = {
val className: String = tpe.erasure.typeSymbol.asClass.fullName
tpe match {
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
// Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
// whereas className is from Scala reflection. This can make it hard to find classes
// in some cases, such as when a class is enclosed in an object (in which case
// Java appends a '$' to the object name but Scala does not).
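// For example, given "object Outer { class Inner }", Scala reflection reports
// the name Outer.Inner while the JVM class name is Outer$Inner.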
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}
}
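// Dispatch sketch (assuming a hypothetical Point2D class annotated with
// @SQLUserDefinedType, as illustrated near UserDefinedType below):
//   schemaFor(typeOf[Point2D])     // Schema(<instance of Point2D's UDT>, nullable = true)
//   schemaFor(typeOf[Option[Int]]) // Schema(IntegerType, nullable = true), as before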

def typeOfObject: PartialFunction[Any, DataType] = {
SQLUserDefinedType.java (new file)
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.annotation;

import java.lang.annotation.*;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.sql.catalyst.types.UserDefinedType;

/**
* ::DeveloperApi::
* A user-defined type which can be automatically recognized by a SQLContext and registered.
*
* WARNING: This annotation will only work if both Java and Scala reflection return the same class
* names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
* is enclosed in an object (a singleton).
*
* WARNING: UDTs are currently only supported from Scala.
*/
// TODO: Should I use @Documented?
@DeveloperApi
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SQLUserDefinedType {

/**
* Returns an instance of the UserDefinedType which can serialize and deserialize the user
* class to and from Catalyst built-in types.
*/
Class<? extends UserDefinedType<?> > udt();
}
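A minimal sketch of how the annotation is attached from Scala (Point2D and Point2DUDT are hypothetical names; a fuller sketch accompanies UserDefinedType below):

import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType

@SQLUserDefinedType(udt = classOf[Point2DUDT])
class Point2D(val x: Double, val y: Double)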
ScalaUdf.scala
@@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.util.ClosureCleaner

/**
* User-defined function.
* @param dataType Return type of function.
*/
case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
extends Expression {

@@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
}
// scalastyle:on

ScalaReflection.convertToCatalyst(result)
ScalaReflection.convertToCatalyst(result, dataType)
}
}
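With this change, a UDF's declared return dataType drives the conversion of its result. A rough sketch (hypothetical UDF; Literal is Catalyst's literal expression):

import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUdf}
import org.apache.spark.sql.catalyst.types.DecimalType

val udf = ScalaUdf((x: Int) => BigDecimal(x), DecimalType.Unlimited, Literal(1) :: Nil)
udf.eval(null)  // the BigDecimal result is converted to Decimal against DecimalType.Unlimited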
dataTypes.scala
@@ -29,11 +29,12 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
import org.apache.spark.sql.catalyst.types.decimal._
import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.types.decimal._

object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -67,6 +68,11 @@ object DataType {
("fields", JArray(fields)),
("type", JString("struct"))) =>
StructType(fields.map(parseStructField))

case JSortedObject(
("class", JString(udtClass)),
("type", JString("udt"))) =>
Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
}
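// Round-trip sketch: UserDefinedType.jsonValue (below) emits, e.g.,
//   {"type":"udt","class":"com.example.Point2DUDT"}
// for a hypothetical UDT, and this case re-instantiates it reflectively; note
// that Class.forName(...).newInstance() requires the UDT to have a no-arg
// constructor.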

private def parseStructField(json: JValue): StructField = json match {
@@ -342,6 +348,7 @@ object FractionalType {
case _ => false
}
}

abstract class FractionalType extends NumericType {
private[sql] val fractional: Fractional[JvmType]
private[sql] val asIntegral: Integral[JvmType]
@@ -565,3 +572,45 @@ case class MapType(
("valueType" -> valueType.jsonValue) ~
("valueContainsNull" -> valueContainsNull)
}

/**
* ::DeveloperApi::
* The data type for User Defined Types (UDTs).
*
* This interface allows a user to make their own classes more interoperable with SparkSQL;
* e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
* a SchemaRDD which has class X in the schema.
*
* For SparkSQL to recognize UDTs, the UDT must be annotated with
* [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]].
*
* The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD.
* The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
*/
@DeveloperApi
abstract class UserDefinedType[UserType] extends DataType with Serializable {

/** Underlying storage type for this UDT */
def sqlType: DataType

/**
* Convert the user type to a SQL datum
*
* TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
* where we need to convert Any to UserType.
*/
def serialize(obj: Any): Any

/** Convert a SQL datum to the user type */
def deserialize(datum: Any): UserType

override private[sql] def jsonValue: JValue = {
("type" -> "udt") ~
("class" -> this.getClass.getName)
}

/**
* Class object for the UserType
*/
def userClass: java.lang.Class[UserType]
}
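Putting the pieces together, a minimal end-to-end sketch of the API above (Point2D and Point2DUDT are hypothetical toy classes, in the spirit of the examples this PR's suites exercise):

import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.types._

@SQLUserDefinedType(udt = classOf[Point2DUDT])
class Point2D(val x: Double, val y: Double)

class Point2DUDT extends UserDefinedType[Point2D] {
  // Store a point as a two-element sequence of doubles.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  override def serialize(obj: Any): Any = obj match {
    case p: Point2D => Seq(p.x, p.y)
  }

  override def deserialize(datum: Any): Point2D = datum match {
    case Seq(x: Double, y: Double) => new Point2D(x, y)
  }

  override def userClass: Class[Point2D] = classOf[Point2D]
}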
ScalaReflectionSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}

import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.types._

case class PrimitiveData(
@@ -239,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
assert(convertToCatalyst(data) === convertedData)
val dataType = schemaFor[PrimitiveData].dataType
assert(convertToCatalyst(data, dataType) === convertedData)
}

test("convert Option[Product] to catalyst") {
val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData))
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData))
assert(convertToCatalyst(data) === convertedData)
val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
Some(primitiveData))
val dataType = schemaFor[OptionalData].dataType
val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true,
Row(1, 1, 1, 1, 1, 1, true))
assert(convertToCatalyst(data, dataType) === convertedData)
}
}