diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 98e41f8575679..7c2737661e581 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -39,6 +39,7 @@
 from array import array
 from operator import itemgetter
 from itertools import imap
+import importlib
 
 from py4j.protocol import Py4JError
 from py4j.java_collections import ListConverter, MapConverter
@@ -52,7 +53,7 @@
 __all__ = [
     "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
     "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
-    "ShortType", "ArrayType", "MapType", "StructField", "StructType",
+    "ShortType", "ArrayType", "MapType", "StructField", "StructType", "UserDefinedType",
     "SQLContext", "HiveContext", "SchemaRDD", "Row"]
@@ -408,6 +409,68 @@ def fromJson(cls, json):
         return StructType([StructField.fromJson(f) for f in json["fields"]])
 
 
+class UserDefinedType(DataType):
+    """
+    :: WARN: Spark Internal Use Only ::
+    SQL User-Defined Type (UDT).
+    """
+
+    @classmethod
+    def sqlType(cls):
+        """
+        Underlying SQL storage type for this UDT.
+        """
+        raise NotImplementedError("UDT must implement sqlType().")
+
+    def serialize(self, obj):
+        """
+        Converts a user-type object into a SQL datum.
+        """
+        raise NotImplementedError("UDT must implement serialize().")
+
+    def deserialize(self, datum):
+        """
+        Converts a SQL datum into a user-type object.
+        """
+        raise NotImplementedError("UDT must implement deserialize().")
+
+    @classmethod
+    def module(cls):
+        """
+        The Python module of the UDT.
+        """
+        raise NotImplementedError("UDT must implement module().")
+
+    @classmethod
+    def scalaUDT(cls):
+        """
+        The class name of the paired Scala UDT.
+        """
+        raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+    @classmethod
+    def json(cls):
+        return json.dumps(cls.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+    @classmethod
+    def jsonValue(cls):
+        schema = {
+            "type": "udt",
+            "pyModule": cls.module(),
+            "pyClass": cls.__name__}
+        if cls.scalaUDT() is not None:
+            schema['class'] = cls.scalaUDT()
+        return schema
+
+    @classmethod
+    def fromJson(cls, json):
+        pyModule = json['pyModule']
+        pyClass = json['pyClass']
+        m = importlib.import_module(pyModule)
+        UDT = getattr(m, pyClass)
+        return UDT()
+
+
 _all_primitive_types = dict((v.typeName(), v)
                             for v in globals().itervalues()
                             if type(v) is PrimitiveTypeSingleton and
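Note: the class above is the entire Python-side contract. sqlType/module/scalaUDT describe the type, serialize/deserialize move values between user objects and the SQL representation, and json/jsonValue/fromJson mirror Scala's DataType JSON. A minimal sketch of a concrete subclass (the Point and PointUDT names here are hypothetical, invented for illustration; the real fixture, ExamplePointUDT, is added in tests.py below):

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y

    class PointUDT(UserDefinedType):

        @classmethod
        def sqlType(cls):
            # stored in SQL as a fixed-length array of doubles
            return ArrayType(DoubleType(), False)

        @classmethod
        def module(cls):
            return 'mypackage.points'

        @classmethod
        def scalaUDT(cls):
            return 'org.example.PointUDT'

        def serialize(self, obj):
            return [obj.x, obj.y]

        def deserialize(self, datum):
            return Point(datum[0], datum[1])
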
StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ return _parse_datatype_json_value(json.loads(json_string)) @@ -479,7 +551,13 @@ def _parse_datatype_json_value(json_value): else: raise ValueError("Could not parse datatype: %s" % json_value) else: - return _all_complex_types[json_value["type"]].fromJson(json_value) + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) # Mapping Python types to Spark SQL DataType @@ -499,10 +577,19 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> from pyspark.tests import ExamplePoint + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ if obj is None: raise ValueError("Can not infer type for None") + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() @@ -547,9 +634,94 @@ def _infer_schema(row): fields = [StructField(k, _infer_type(v), True) for k, v in items] return StructType(fields) +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> from pyspark.tests import ExamplePointUDT + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> from pyspark.tests import ExamplePointUDT, ExamplePoint + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) def _create_converter(obj, dataType): - """Create an converter to drop the names of fields in obj """ + """Create an converter to drop the names of fields in obj""" if isinstance(dataType, ArrayType): conv = _create_converter(obj[0], dataType.elementType) return lambda row: map(conv, row) @@ -780,6 +952,10 @@ def _verify_type(obj, dataType): if obj is None: return + if isinstance(dataType, UserDefinedType): + # TODO: check UDT + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType @@ -854,6 +1030,8 @@ def _has_struct_or_date(dt): return _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -924,6 +1102,9 @@ def Dict(d): elif isinstance(dataType, DateType): return datetime.date + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -1184,6 +1365,10 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) @@ -1436,6 +1621,33 @@ def hql(self, hqlQuery): class LocalHiveContext(HiveContext): + """Starts up an instance of hive where metadata is stored locally. + + An in-process metadata data is created with data stored in ./metadata. + Warehouse data is stored in in ./warehouse. + + # >>> import os + # >>> hiveCtx = LocalHiveContext(sc) + # >>> try: + # ... supress = hiveCtx.sql("DROP TABLE src") + # ... except Exception: + # ... pass + # >>> kv1 = os.path.join(os.environ["SPARK_HOME"], + # ... 'examples/src/main/resources/kv1.txt') + # >>> supress = hiveCtx.sql( + # ... 
"CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + # >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" + # ... % kv1) + # >>> results = hiveCtx.sql("FROM src SELECT value" + # ... ).map(lambda r: int(r.value.split('_')[1])) + # >>> num = results.count() + # >>> reduce_sum = results.reduce(lambda x, y: x + y) + # >>> num + # 500 + # >>> reduce_sum + # 130091 + """ + def __init__(self, sparkContext, sqlContext=None): HiveContext.__init__(self, sparkContext, sqlContext) warnings.warn("LocalHiveContext is deprecated. " diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 37a128907b3a7..62cdc80fb6ab2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,8 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType from pyspark import shuffle _have_scipy = False @@ -791,6 +792,53 @@ def test_convert_row_to_dict(self): self.assertEqual(1, row.asDict()["la"]) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> schema + StructType(List(StructField(label,DoubleType,false),StructField(point,ExamplePointUDT,false))) + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. 
+ """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + class InputFormatTests(ReusedPySparkTestCase): @classmethod diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cc5015ad3c013..7c9901f4fd331 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -71,6 +71,8 @@ object DataType { case JSortedObject( ("class", JString(udtClass)), + ("pyClass", _), + ("pyModule", _), ("type", JString("udt"))) => Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } @@ -593,6 +595,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Underlying storage type for this UDT */ def sqlType: DataType + def pyUDT: (String, String) = (null, null) + /** * Convert the user type to a SQL datum * @@ -605,8 +609,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { def deserialize(datum: Any): UserType override private[sql] def jsonValue: JValue = { + val (pyModule, pyClass) = pyUDT ("type" -> "udt") ~ - ("class" -> this.getClass.getName) + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyClass) ~ + ("pyModule" -> pyModule) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9e61d18f7e926..bd8d1fc8ce775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -483,6 +483,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true + case udt: UserDefinedType[_] => true case other => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 401798e317e96..4a8178e0765bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -103,7 +103,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { override def equals(other: Any): Boolean = other match { case that: Row => (that canEqual this) && - row == that.row + row == that.row // Should this be row.equals(that.row)? 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9e61d18f7e926..bd8d1fc8ce775 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -483,6 +483,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case ArrayType(_, _) => true
       case MapType(_, _, _) => true
      case StructType(_) => true
+      case udt: UserDefinedType[_] => true
       case other => false
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index 401798e317e96..4a8178e0765bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -103,7 +103,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable {
   override def equals(other: Any): Boolean = other match {
     case that: Row =>
       (that canEqual this) &&
-      row == that.row
+      row == that.row // Should this be row.equals(that.row)?
     case _ => false
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 997669051ed07..a83cf5d441d1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -135,6 +135,8 @@ object EvaluatePython {
       case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
     }.asJava
 
+    case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
+
     case (dec: BigDecimal, dt: DecimalType) => dec.underlying()  // Pyrolite can handle BigDecimal
 
     // Pyrolite can handle Timestamp
@@ -177,6 +179,9 @@ object EvaluatePython {
     case (c: java.util.Calendar, TimestampType) =>
       new java.sql.Timestamp(c.getTime().getTime())
 
+    case (_, udt: UserDefinedType[_]) =>
+      fromJava(obj, udt.sqlType)
+
     case (c: Int, ByteType) => c.toByte
     case (c: Long, ByteType) => c.toByte
     case (c: Int, ShortType) => c.toShort
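Note: with these two cases, a UDT value always crosses the Pyrolite boundary in its sqlType representation: toJava serializes before pickling, and fromJava recurses into the underlying type. Seen from Python, the two directions look like this (a sketch; _python_to_sql_converter is the module-private helper added above):

    >>> from pyspark.sql import _python_to_sql_converter
    >>> from pyspark.tests import ExamplePoint, ExamplePointUDT
    >>> _python_to_sql_converter(ExamplePointUDT())(ExamplePoint(1.0, 2.0))  # outbound
    [1.0, 2.0]
    >>> ExamplePointUDT().deserialize([1.0, 2.0])  # inbound
    ExamplePoint(1.0,2.0)
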
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 51dad54f1a3f3..6c1d4829de29d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -433,6 +433,9 @@ package object sql {
   @DeveloperApi
   val StructField = catalyst.types.StructField
 
+  @DeveloperApi
+  type UserDefinedType[T] = catalyst.types.UserDefinedType[T]
+
   /**
    * Converts a logical plan into zero or more SparkPlans.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
new file mode 100644
index 0000000000000..a26a96485df0f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * An example class to demonstrate UDT in Scala, Java, and Python.
+ * @param x x coordinate
+ * @param y y coordinate
+ */
+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+private[sql] class ExamplePoint(val x: Double, val y: Double)
+
+/**
+ * User-defined type for [[ExamplePoint]].
+ */
+private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+  override def sqlType: DataType = ArrayType(DoubleType, false)
+
+  override def serialize(obj: Any): Seq[Double] = {
+    obj match {
+      case p: ExamplePoint =>
+        Seq(p.x, p.y)
+    }
+  }
+
+  override def pyUDT: (String, String) = ("pyspark.tests", "ExamplePointUDT")
+
+  override def deserialize(datum: Any): ExamplePoint = {
+    datum match {
+      case values: Seq[_] =>
+        val xy = values.asInstanceOf[Seq[Double]]
+        assert(xy.length == 2)
+        new ExamplePoint(xy(0), xy(1))
+      case values: util.ArrayList[_] =>
+        // Pyrolite unpickles Python lists as java.util.ArrayList
+        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
+        new ExamplePoint(xy(0), xy(1))
+    }
+  }
+
+  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 1bc15146f0fe8..3fa4a7c6481d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -27,7 +27,6 @@
 import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.UserDefinedType
-
 protected[sql] object DataTypeConversions {
 
   /**
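Note: an end-to-end test tying these pieces together could look like the following (a sketch of a method that would fit SQLTests in python/pyspark/tests.py, whose setUp creates self.sc and self.sqlCtx; it is not part of this patch). ExamplePoint defines no __eq__, so the fields are compared directly:

    def test_apply_schema_with_udt(self):
        row = (1.0, ExamplePoint(1.0, 2.0))
        rdd = self.sc.parallelize([row])
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        srdd = self.sqlCtx.applySchema(rdd, schema)
        point = srdd.first().point
        self.assertEqual(point.x, 1.0)
        self.assertEqual(point.y, 2.0)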