
Commit

Converts types of values based on defined schema.
yhuai committed Jul 30, 2014
1 parent 4ceeb66 commit c712fbf
Showing 5 changed files with 134 additions and 96 deletions.
@@ -544,7 +544,8 @@ private[spark] object PythonRDD extends Logging {
}

/**
* Convert an RDD of serialized Python dictionaries to Scala Maps
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
* It is only used by pyspark.sql.
* TODO: Support more Python types.
*/
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
58 changes: 23 additions & 35 deletions python/pyspark/sql.py
@@ -111,12 +111,7 @@ def __repr__(self):
class FloatType(object):
"""Spark SQL FloatType
For now, please use L{DoubleType} instead of using L{FloatType}.
Because query evaluation is done in Scala, java.lang.Double will be used
for Python float numbers. Because the underlying JVM type of FloatType is
java.lang.Float (in Java) and Float (in Scala), and we are trying to cast the type,
there will be a java.lang.ClassCastException
if FloatType (Python) is used.
The data type representing single precision floating-point values.
"""
__metaclass__ = PrimitiveTypeSingleton
@@ -128,12 +123,7 @@ def __repr__(self):
class ByteType(object):
"""Spark SQL ByteType
For now, please use L{IntegerType} instead of using L{ByteType}.
Because query evaluation is done in Scala, java.lang.Integer will be used
for Python int numbers. Because the underlying JVM type of ByteType is
java.lang.Byte (in Java) and Byte (in Scala), and we are trying to cast the type,
there will be a java.lang.ClassCastException
if ByteType (Python) is used.
The data type representing int values with 1 signed byte.
"""
__metaclass__ = PrimitiveTypeSingleton
@@ -170,12 +160,7 @@ def __repr__(self):
class ShortType(object):
"""Spark SQL ShortType
For now, please use L{IntegerType} instead of using L{ShortType}.
Because query evaluation is done in Scala, java.lang.Integer will be used
for Python int numbers. Because the underlying JVM type of ShortType is
java.lang.Short (in Java) and Short (in Scala), and we are trying to cast the type,
there will be a java.lang.ClassCastException
if ShortType (Python) is used.
The data type representing int values with 2 signed bytes.
"""
__metaclass__ = PrimitiveTypeSingleton
@@ -198,7 +183,6 @@ def __init__(self, elementType, containsNull=False):
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.
:return:
>>> ArrayType(StringType) == ArrayType(StringType, False)
True
@@ -238,7 +222,6 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
:param keyType: the data type of keys.
:param valueType: the data type of values.
:param valueContainsNull: indicates whether values contain null values.
:return:
>>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True)
True
@@ -279,7 +262,6 @@ def __init__(self, name, dataType, nullable):
:param name: the name of this field.
:param dataType: the data type of this field.
:param nullable: indicates whether values of this field can be null.
:return:
>>> StructField("f1", StringType, True) == StructField("f1", StringType, True)
True
@@ -314,8 +296,6 @@ class StructType(object):
"""
def __init__(self, fields):
"""Creates a StructType
:param fields:
:return:
>>> struct1 = StructType([StructField("f1", StringType, True)])
>>> struct2 = StructType([StructField("f1", StringType, True)])
@@ -342,11 +322,7 @@ def __ne__(self, other):


def _parse_datatype_list(datatype_list_string):
"""Parses a list of comma separated data types.
:param datatype_list_string:
:return:
"""
"""Parses a list of comma separated data types."""
index = 0
datatype_list = []
start = 0
@@ -372,9 +348,6 @@ def _parse_datatype_list(datatype_list_string):
def _parse_datatype_string(datatype_string):
"""Parses the given data type string.
:param datatype_string:
:return:
>>> def check_datatype(datatype):
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__())
... python_datatype = _parse_datatype_string(scala_datatype.toString())
@@ -582,9 +555,6 @@ def inferSchema(self, rdd):

def applySchema(self, rdd, schema):
"""Applies the given schema to the given RDD of L{dict}s.
:param rdd:
:param schema:
:return:
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
@@ -594,9 +564,27 @@ def applySchema(self, rdd, schema):
>>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
... {"field1" : 3, "field2": "row3"}]
True
>>> from datetime import datetime
>>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0,
... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2},
... "list": [1, 2, 3]}])
>>> schema = StructType([
... StructField("byte", ByteType(), False),
... StructField("short", ShortType(), False),
... StructField("float", FloatType(), False),
... StructField("time", TimestampType(), False),
... StructField("map", MapType(StringType(), IntegerType(), False), False),
... StructField("struct", StructType([StructField("b", ShortType(), False)]), False),
... StructField("list", ArrayType(ByteType(), False), False),
... StructField("null", DoubleType(), True)])
>>> srdd = sqlCtx.applySchema(rdd, schema).map(
... lambda x: (
... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null))
>>> srdd.collect()[0]
(127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
"""
jrdd = self._pythonToJavaMap(rdd._jrdd)
srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema.__repr__())
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__())
return SchemaRDD(srdd, self)

def registerRDDAsTable(self, rdd, tableName):
@@ -313,13 +313,13 @@ case class StructType(fields: Seq[StructField]) extends DataType {
*/
lazy val fieldNames: Seq[String] = fields.map(_.name)
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet

private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
/**
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
fields.find(f => f.name == name).getOrElse(
nameToField.get(name).getOrElse(
throw new IllegalArgumentException(s"Field ${name} does not exist."))
}

Expand All @@ -333,6 +333,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
throw new IllegalArgumentException(
s"Field ${nonExistFields.mkString(",")} does not exist.")
}
// Preserve the original order of fields.
StructType(fields.filter(f => names.contains(f.name)))
}
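
Not part of the commit: a minimal sketch of how the reworked StructType lookups behave, assuming the catalyst types package referenced elsewhere in this diff (org.apache.spark.sql.catalyst.types) and the Set-based apply shown in the second hunk, e.g. in a spark-shell session on this branch.

import org.apache.spark.sql.catalyst.types._

val struct = StructType(Seq(
  StructField("a", StringType, nullable = true),
  StructField("b", IntegerType, nullable = false)))

// Name lookup now goes through the precomputed nameToField map instead of a linear fields.find.
struct("a")            // StructField("a", StringType, true)
// struct("missing")   // IllegalArgumentException: Field missing does not exist.

// Selecting a subset of fields keeps the original schema order, whatever order the names are given in.
struct(Set("b", "a"))  // StructType(Seq(StructField("a", ...), StructField("b", ...)))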

160 changes: 103 additions & 57 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -125,29 +125,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
new SchemaRDD(this, logicalPlan)
}

/**
* Parses the data type in our internal string representation. The data type string should
* have the same format as the one generated by `toString` in scala.
* It is only used by PySpark.
*/
private[sql] def parseDataType(dataTypeString: String): DataType = {
val parser = org.apache.spark.sql.catalyst.types.DataType
parser(dataTypeString)
}

/**
* Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
*/
private[sql] def applySchema(rdd: RDD[Map[String, _]], schemaString: String): SchemaRDD = {
val schema = parseDataType(schemaString).asInstanceOf[StructType]
val rowRdd = rdd.mapPartitions { iter =>
iter.map { map =>
new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
}
}
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd)))
}

/**
* Loads a Parquet file, returning the result as a [[SchemaRDD]].
*
@@ -438,6 +415,7 @@
*/
private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
import scala.collection.JavaConversions._

def typeOfComplexValue: PartialFunction[Any, DataType] = {
case c: java.util.Calendar => TimestampType
case c: java.util.List[_] =>
@@ -453,48 +431,116 @@
def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue

val firstRow = rdd.first()
val schema = StructType(
firstRow.map { case (fieldName, obj) =>
StructField(fieldName, typeOfObject(obj), true)
}.toSeq)

def needTransform(obj: Any): Boolean = obj match {
case c: java.util.List[_] => true
case c: java.util.Map[_, _] => true
case c if c.getClass.isArray => true
case c: java.util.Calendar => true
case c => false
val fields = firstRow.map {
case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true)
}.toSeq

applySchemaToPythonRDD(rdd, StructType(fields))
}

/**
* Parses the data type in our internal string representation. The data type string should
* have the same format as the one generated by `toString` in scala.
* It is only used by PySpark.
*/
private[sql] def parseDataType(dataTypeString: String): DataType = {
val parser = org.apache.spark.sql.catalyst.types.DataType
parser(dataTypeString)
}

/**
* Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
*/
private[sql] def applySchemaToPythonRDD(
rdd: RDD[Map[String, _]],
schemaString: String): SchemaRDD = {
val schema = parseDataType(schemaString).asInstanceOf[StructType]
applySchemaToPythonRDD(rdd, schema)
}

/**
* Apply a schema defined by the schema to an RDD. It is only used by PySpark.
*/
private[sql] def applySchemaToPythonRDD(
rdd: RDD[Map[String, _]],
schema: StructType): SchemaRDD = {
import scala.collection.JavaConversions._
import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}

def needsConversion(dataType: DataType): Boolean = dataType match {
case ByteType => true
case ShortType => true
case FloatType => true
case TimestampType => true
case ArrayType(_, _) => true
case MapType(_, _, _) => true
case StructType(_) => true
case other => false
}

// convert JList, JArray into Seq, convert JMap into Map
// convert Calendar into Timestamp
def transform(obj: Any): Any = obj match {
case c: java.util.List[_] => c.map(transform).toSeq
case c: java.util.Map[_, _] => c.map {
case (key, value) => (key, transform(value))
}.toMap
case c if c.getClass.isArray =>
c.asInstanceOf[Array[_]].map(transform).toSeq
case c: java.util.Calendar =>
new java.sql.Timestamp(c.getTime().getTime())
case c => c
// Converts value to the type specified by the data type.
// Because Python does not have data types for TimestampType, FloatType, ShortType, and
// ByteType, we need to explicitly convert values in columns of these data types to the desired
// JVM data types.
def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
// TODO: We should check nullable
case (null, _) => null

case (c: java.util.List[_], ArrayType(elementType, _)) =>
val converted = c.map { e => convert(e, elementType)}
JListWrapper(converted)

case (c: java.util.Map[_, _], struct: StructType) =>
val row = new GenericMutableRow(struct.fields.length)
struct.fields.zipWithIndex.foreach {
case (field, i) =>
val value = convert(c.get(field.name), field.dataType)
row.update(i, value)
}
row

case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
val converted = c.map {
case (key, value) =>
(convert(key, keyType), convert(value, valueType))
}
JMapWrapper(converted)

case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType))
converted: Seq[Any]

case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime())
case (c: Int, ByteType) => c.toByte
case (c: Int, ShortType) => c.toShort
case (c: Double, FloatType) => c.toFloat

case (c, _) => c
}

val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) })
} else {
rdd
}

val need = firstRow.exists { case (key, value) => needTransform(value) }
val transformed = if (need) {
rdd.mapPartitions { iter =>
iter.map {
m => m.map {case (key, value) => (key, transform(value))}
val rowRdd = convertedRdd.mapPartitions { iter =>
val row = new GenericMutableRow(schema.fields.length)
val fieldsWithIndex = schema.fields.zipWithIndex
iter.map { m =>
// We cannot use m.values because the order of values returned by m.values may not
// match fields order.
fieldsWithIndex.foreach {
case (field, i) =>
val value =
m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull
row.update(i, value)
}
}
} else rdd

val rowRdd = transformed.mapPartitions { iter =>
iter.map { map =>
new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
row: Row
}
}

new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd)))
}

}
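
Not part of the commit: the convert / needsConversion pair above is the core of the change. Python has no Byte, Short, Float, or Timestamp, so values arrive from pyspark as Int, Int, Double, and java.util.Calendar and must be coerced to the JVM types the schema demands; each row is then filled by schema field position, not by map iteration order. A self-contained sketch of those two ideas in plain Scala (SketchType and friends are hypothetical stand-ins, not the Spark classes above):

import java.util.{Calendar, HashMap => JHashMap}

object ConversionSketch extends App {
  // Hypothetical stand-ins for the schema's data types (illustration only).
  sealed trait SketchType
  case object SketchByte extends SketchType
  case object SketchFloat extends SketchType
  case object SketchTimestamp extends SketchType
  case object SketchString extends SketchType

  // Coerce a loosely typed value to the JVM type the schema asks for,
  // mirroring the Int -> Byte, Double -> Float, Calendar -> Timestamp cases above.
  def coerce(obj: Any, dt: SketchType): Any = (obj, dt) match {
    case (null, _)                      => null
    case (i: Int, SketchByte)           => i.toByte
    case (d: Double, SketchFloat)       => d.toFloat
    case (c: Calendar, SketchTimestamp) => new java.sql.Timestamp(c.getTime.getTime)
    case (v, _)                         => v
  }

  // Schema order, not map order, decides where each value lands in the row.
  val schema = Seq("byte" -> SketchByte, "float" -> SketchFloat, "name" -> SketchString)

  val pythonRow = new JHashMap[String, Any]()
  pythonRow.put("name", "row1")   // insertion order deliberately differs from schema order
  pythonRow.put("float", 1.0)
  pythonRow.put("byte", 127)

  val row: Array[Any] = schema.map { case (field, dt) => coerce(pythonRow.get(field), dt) }.toArray
  println(row.mkString(", "))     // 127, 1.0, row1 -- in schema order, as Byte, Float, String
}

The real convert additionally recurses into ArrayType, MapType, and nested StructType values, as the list, map, and struct cases above show.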
4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import java.util.{Map => JMap, List => JList, Set => JSet}
import java.util.{Map => JMap, List => JList}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -380,6 +380,8 @@ class SchemaRDD(
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
import scala.collection.Map

def toJava(obj: Any, dataType: DataType): Any = dataType match {
case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
case array: ArrayType => obj match {
