Commit

add Python UDT

mengxr committed Nov 3, 2014
1 parent ebd6480 commit b7f666d
Showing 9 changed files with 347 additions and 8 deletions.
220 changes: 216 additions & 4 deletions python/pyspark/sql.py
@@ -39,6 +39,7 @@
from array import array
from operator import itemgetter
from itertools import imap
import importlib

from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter
@@ -52,7 +53,7 @@
__all__ = [
    "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
    "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
    "ShortType", "ArrayType", "MapType", "StructField", "StructType",
    "ShortType", "ArrayType", "MapType", "StructField", "StructType", "UserDefinedType",
    "SQLContext", "HiveContext", "SchemaRDD", "Row"]


@@ -408,6 +409,70 @@ def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])


class UserDefinedType(DataType):
    """
    :: WARN: Spark Internal Use Only ::
    SQL User-Defined Type (UDT).
    """

    @classmethod
    def sqlType(cls):
        """
        Underlying SQL storage type for this UDT.
        """
        raise NotImplementedError("UDT must implement sqlType().")

    @classmethod
    def serialize(cls, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        raise NotImplementedError("UDT must implement serialize().")

    @classmethod
    def deserialize(cls, datum):
        """
        Converts a SQL datum into a user-type object.
        """
        raise NotImplementedError("UDT must implement deserialize().")

    @classmethod
    def module(cls):
        """
        The Python module of the UDT.
        """
        raise NotImplementedError("UDT must implement module().")

    @classmethod
    def scalaUDT(cls):
        """
        The class name of the paired Scala UDT.
        """
        raise NotImplementedError("UDT must have a paired Scala UDT.")

    @classmethod
    def json(cls):
        return json.dumps(cls.jsonValue(), separators=(',', ':'), sort_keys=True)

    @classmethod
    def jsonValue(cls):
        schema = {
            "type": "udt",
            "pyModule": cls.module(),
            "pyClass": cls.__name__}
        if cls.scalaUDT() is not None:
            schema['class'] = cls.scalaUDT()
        return schema

    @classmethod
    def fromJson(cls, json):
        pyModule = json['pyModule']
        pyClass = json['pyClass']
        m = importlib.import_module(pyModule)
        UDT = getattr(m, pyClass)
        return UDT()
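
A minimal editorial sketch (not part of this commit) of a custom UDT written against the interface above. Vector2D, Vector2DUDT, the module path, and the Scala class name are all invented for illustration; the pattern mirrors the ExamplePointUDT test helper added in pyspark/tests.py later in this commit.

from pyspark.sql import UserDefinedType, ArrayType, DoubleType

class Vector2DUDT(UserDefinedType):
    """Hypothetical UDT that stores a 2-D vector as two doubles."""

    @classmethod
    def sqlType(cls):
        # underlying SQL storage: a fixed-length array of doubles
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "mypkg.linalg"  # hypothetical module that defines this UDT

    @classmethod
    def scalaUDT(cls):
        return "org.example.sql.Vector2DUDT"  # hypothetical Scala counterpart

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Vector2D(datum[0], datum[1])

class Vector2D(object):
    # _infer_type() looks for __UDT__ when inferring a schema for this object
    __UDT__ = Vector2DUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y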


_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
@@ -460,6 +525,13 @@ def _parse_datatype_json_string(json_string):
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
>>> from pyspark.tests import ExamplePointUDT
>>> check_datatype(ExamplePointUDT())
True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> check_datatype(structtype_with_udt)
True
"""
return _parse_datatype_json_value(json.loads(json_string))

@@ -479,7 +551,13 @@ def _parse_datatype_json_value(json_value):
        else:
            raise ValueError("Could not parse datatype: %s" % json_value)
    else:
        return _all_complex_types[json_value["type"]].fromJson(json_value)
        tpe = json_value["type"]
        if tpe in _all_complex_types:
            return _all_complex_types[tpe].fromJson(json_value)
        elif tpe == 'udt':
            return UserDefinedType.fromJson(json_value)
        else:
            raise ValueError("not supported type: %s" % tpe)
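
An editorial illustration (not part of the diff): the 'udt' branch above expects JSON of the shape produced by UserDefinedType.jsonValue(). For the ExamplePointUDT used in the doctests, the round trip should look roughly like this (the output is reconstructed from the json()/jsonValue() code above, so treat it as a sketch):

>>> from pyspark.tests import ExamplePointUDT
>>> ExamplePointUDT.json()   # compact separators, sorted keys
'{"class":"org.apache.spark.sql.test.ExamplePointUDT","pyClass":"ExamplePointUDT","pyModule":"pyspark.tests","type":"udt"}'
>>> _parse_datatype_json_string(ExamplePointUDT.json())
ExamplePointUDT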


# Mapping Python types to Spark SQL DataType
@@ -499,10 +577,19 @@ def _parse_datatype_json_value(json_value):


def _infer_type(obj):
    """Infer the DataType from obj"""
    """Infer the DataType from obj
    >>> from pyspark.tests import ExamplePoint
    >>> p = ExamplePoint(1.0, 2.0)
    >>> _infer_type(p)
    ExamplePointUDT
    """
    if obj is None:
        raise ValueError("Can not infer type for None")

    if hasattr(obj, '__UDT__'):
        return obj.__UDT__

    dataType = _type_mappings.get(type(obj))
    if dataType is not None:
        return dataType()
@@ -547,9 +634,94 @@ def _infer_schema(row):
    fields = [StructField(k, _infer_type(v), True) for k, v in items]
    return StructType(fields)

def _need_python_to_sql_conversion(dataType):
    """
    Checks whether we need python to sql conversion for the given type.
    For now, only UDTs need this conversion.
    >>> _need_python_to_sql_conversion(DoubleType())
    False
    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
    >>> _need_python_to_sql_conversion(schema0)
    False
    >>> from pyspark.tests import ExamplePointUDT
    >>> _need_python_to_sql_conversion(ExamplePointUDT())
    True
    >>> schema1 = ArrayType(ExamplePointUDT(), False)
    >>> _need_python_to_sql_conversion(schema1)
    True
    >>> schema2 = StructType([StructField("label", DoubleType(), False),
    ...                       StructField("point", ExamplePointUDT(), False)])
    >>> _need_python_to_sql_conversion(schema2)
    True
    """
    if isinstance(dataType, StructType):
        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
    elif isinstance(dataType, ArrayType):
        return _need_python_to_sql_conversion(dataType.elementType)
    elif isinstance(dataType, MapType):
        return _need_python_to_sql_conversion(dataType.keyType) or \
            _need_python_to_sql_conversion(dataType.valueType)
    elif isinstance(dataType, UserDefinedType):
        return True
    else:
        return False

def _python_to_sql_converter(dataType):
    """
    Returns a converter that converts a Python object into a SQL datum for the given type.
    >>> conv = _python_to_sql_converter(DoubleType())
    >>> conv(1.0)
    1.0
    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
    >>> conv([1.0, 2.0])
    [1.0, 2.0]
    >>> from pyspark.tests import ExamplePointUDT, ExamplePoint
    >>> conv = _python_to_sql_converter(ExamplePointUDT())
    >>> conv(ExamplePoint(1.0, 2.0))
    [1.0, 2.0]
    >>> schema = StructType([StructField("label", DoubleType(), False),
    ...                      StructField("point", ExamplePointUDT(), False)])
    >>> conv = _python_to_sql_converter(schema)
    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
    (1.0, [1.0, 2.0])
    """
    if not _need_python_to_sql_conversion(dataType):
        return lambda x: x

    if isinstance(dataType, StructType):
        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
        converters = map(_python_to_sql_converter, types)
        def converter(obj):
            if isinstance(obj, dict):
                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
            elif isinstance(obj, tuple):
                if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
                    return tuple(c(v) for c, v in zip(converters, obj))
                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
                    d = dict(obj)
                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
                else:
                    return tuple(c(v) for c, v in zip(converters, obj))
            else:
                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
        return converter
    elif isinstance(dataType, ArrayType):
        element_converter = _python_to_sql_converter(dataType.elementType)
        return lambda a: [element_converter(v) for v in a]
    elif isinstance(dataType, MapType):
        key_converter = _python_to_sql_converter(dataType.keyType)
        value_converter = _python_to_sql_converter(dataType.valueType)
        return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
    elif isinstance(dataType, UserDefinedType):
        return lambda obj: dataType.serialize(obj)
    else:
        raise ValueError("Unexpected type %r" % dataType)

def _create_converter(obj, dataType):
    """Create an converter to drop the names of fields in obj """
    """Create a converter to drop the names of fields in obj"""
    if isinstance(dataType, ArrayType):
        conv = _create_converter(obj[0], dataType.elementType)
        return lambda row: map(conv, row)
@@ -780,6 +952,10 @@ def _verify_type(obj, dataType):
    if obj is None:
        return

    if isinstance(dataType, UserDefinedType):
        # TODO: check UDT
        return

    _type = type(dataType)
    assert _type in _acceptable_types, "unknown datatype: %s" % dataType

@@ -854,6 +1030,8 @@ def _has_struct_or_date(dt):
        return _has_struct_or_date(dt.valueType)
    elif isinstance(dt, DateType):
        return True
    elif isinstance(dt, UserDefinedType):
        return True
    return False


@@ -924,6 +1102,9 @@ def Dict(d):
    elif isinstance(dataType, DateType):
        return datetime.date

    elif isinstance(dataType, UserDefinedType):
        return lambda datum: dataType.deserialize(datum)

    elif not isinstance(dataType, StructType):
        raise Exception("unexpected data type: %s" % dataType)

@@ -1184,6 +1365,10 @@ def applySchema(self, rdd, schema):
        for row in rows:
            _verify_type(row, schema)

        # convert python objects to sql data
        converter = _python_to_sql_converter(schema)
        rdd = rdd.map(converter)

        batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
        jrdd = self._pythonToJava(rdd._jrdd, batched)
        srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
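
An editorial, hedged end-to-end sketch of the new conversion path (not part of the diff). sc and sqlCtx are assumed to be a live SparkContext and SQLContext, and ExamplePoint/ExamplePointUDT are the test helpers added in pyspark/tests.py below.

from pyspark.sql import StructType, StructField, DoubleType
from pyspark.tests import ExamplePoint, ExamplePointUDT

schema = StructType([StructField("label", DoubleType(), False),
                     StructField("point", ExamplePointUDT(), False)])
rdd = sc.parallelize([(1.0, ExamplePoint(1.0, 2.0)),
                      (0.0, ExamplePoint(3.0, 4.0))])
# the converter installed above serializes each ExamplePoint to [x, y]
# before the rows reach applySchemaToPythonRDD on the JVM side
srdd = sqlCtx.applySchema(rdd, schema)
point = srdd.first().point   # comes back deserialized as an ExamplePoint
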
@@ -1436,6 +1621,33 @@ def hql(self, hqlQuery):

class LocalHiveContext(HiveContext):

    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metadata store is created with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    # >>> import os
    # >>> hiveCtx = LocalHiveContext(sc)
    # >>> try:
    # ...     supress = hiveCtx.sql("DROP TABLE src")
    # ... except Exception:
    # ...     pass
    # >>> kv1 = os.path.join(os.environ["SPARK_HOME"],
    # ...                    'examples/src/main/resources/kv1.txt')
    # >>> supress = hiveCtx.sql(
    # ...     "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    # >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src"
    # ...                       % kv1)
    # >>> results = hiveCtx.sql("FROM src SELECT value"
    # ...                       ).map(lambda r: int(r.value.split('_')[1]))
    # >>> num = results.count()
    # >>> reduce_sum = results.reduce(lambda x, y: x + y)
    # >>> num
    # 500
    # >>> reduce_sum
    # 130091
    """

    def __init__(self, sparkContext, sqlContext=None):
        HiveContext.__init__(self, sparkContext, sqlContext)
        warnings.warn("LocalHiveContext is deprecated. "
50 changes: 49 additions & 1 deletion python/pyspark/tests.py
@@ -49,7 +49,8 @@
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
    CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
from pyspark.sql import SQLContext, IntegerType, DoubleType, Row, ArrayType, StructType, \
    StructField, UserDefinedType
from pyspark import shuffle

_have_scipy = False
@@ -791,6 +792,53 @@ def test_convert_row_to_dict(self):
        self.assertEqual(1, row.asDict()["la"])


class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.
    >>> schema = StructType([StructField("label", DoubleType(), False),
    ...                      StructField("point", ExamplePointUDT(), False)])
    >>> schema
    StructType(List(StructField(label,DoubleType,false),StructField(point,ExamplePointUDT,false)))
    """

    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])

    @classmethod
    def module(cls):
        return 'pyspark.tests'

    @classmethod
    def scalaUDT(cls):
        return 'org.apache.spark.sql.test.ExamplePointUDT'


class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.
    """

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)
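
An editorial note (not part of the diff): because ExamplePoint carries __UDT__, schema inference should pick the UDT up without an explicit schema, roughly as sketched below; sc and sqlCtx are again assumed to be live contexts.

from pyspark.sql import Row
from pyspark.tests import ExamplePoint, ExamplePointUDT

row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
srdd = sqlCtx.inferSchema(sc.parallelize([row]))
point_field = [f for f in srdd.schema().fields if f.name == "point"][0]
print isinstance(point_field.dataType, ExamplePointUDT)   # expected: True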


class InputFormatTests(ReusedPySparkTestCase):

    @classmethod