diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0d8453fb184a3..f551a59ee3fe8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -544,7 +544,8 @@ private[spark] object PythonRDD extends Logging {
   }
 
   /**
-   * Convert an RDD of serialized Python dictionaries to Scala Maps
+   * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
+   * It is only used by pyspark.sql.
    * TODO: Support more Python types.
    */
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3f2f4dad49a83..13f0ed4e35490 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -111,12 +111,7 @@ def __repr__(self):
 class FloatType(object):
     """Spark SQL FloatType
 
-    For now, please use L{DoubleType} instead of using L{FloatType}.
-    Because query evaluation is done in Scala, java.lang.Double will be be used
-    for Python float numbers. Because the underlying JVM type of FloatType is
-    java.lang.Float (in Java) and Float (in scala), and we are trying to cast the type,
-    there will be a java.lang.ClassCastException
-    if FloatType (Python) is used.
+    The data type representing single precision floating-point values.
 
     """
     __metaclass__ = PrimitiveTypeSingleton
@@ -128,12 +123,7 @@ def __repr__(self):
 class ByteType(object):
     """Spark SQL ByteType
 
-    For now, please use L{IntegerType} instead of using L{ByteType}.
-    Because query evaluation is done in Scala, java.lang.Integer will be be used
-    for Python int numbers. Because the underlying JVM type of ByteType is
-    java.lang.Byte (in Java) and Byte (in scala), and we are trying to cast the type,
-    there will be a java.lang.ClassCastException
-    if ByteType (Python) is used.
+    The data type representing int values with 1 signed byte.
 
     """
     __metaclass__ = PrimitiveTypeSingleton
@@ -170,12 +160,7 @@ def __repr__(self):
 class ShortType(object):
     """Spark SQL ShortType
 
-    For now, please use L{IntegerType} instead of using L{ShortType}.
-    Because query evaluation is done in Scala, java.lang.Integer will be be used
-    for Python int numbers. Because the underlying JVM type of ShortType is
-    java.lang.Short (in Java) and Short (in scala), and we are trying to cast the type,
-    there will be a java.lang.ClassCastException
-    if ShortType (Python) is used.
+    The data type representing int values with 2 signed bytes.
 
     """
     __metaclass__ = PrimitiveTypeSingleton
@@ -198,7 +183,6 @@ def __init__(self, elementType, containsNull=False):
 
         :param elementType: the data type of elements.
        :param containsNull: indicates whether the list contains None values.
-        :return:
 
         >>> ArrayType(StringType) == ArrayType(StringType, False)
         True
@@ -238,7 +222,6 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
         :param keyType: the data type of keys.
         :param valueType: the data type of values.
         :param valueContainsNull: indicates whether values contains null values.
-        :return:
 
         >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True)
         True
@@ -279,7 +262,6 @@ def __init__(self, name, dataType, nullable):
         :param name: the name of this field.
         :param dataType: the data type of this field.
         :param nullable: indicates whether values of this field can be null.
-        :return:
 
         >>> StructField("f1", StringType, True) == StructField("f1", StringType, True)
         True
@@ -314,8 +296,6 @@ class StructType(object):
     """
     def __init__(self, fields):
         """Creates a StructType
-        :param fields:
-        :return:
 
         >>> struct1 = StructType([StructField("f1", StringType, True)])
         >>> struct2 = StructType([StructField("f1", StringType, True)])
@@ -342,11 +322,7 @@ def __ne__(self, other):
 
 
 def _parse_datatype_list(datatype_list_string):
-    """Parses a list of comma separated data types.
-
-    :param datatype_list_string:
-    :return:
-    """
+    """Parses a list of comma separated data types."""
     index = 0
     datatype_list = []
     start = 0
@@ -372,9 +348,6 @@ def _parse_datatype_list(datatype_list_string):
 def _parse_datatype_string(datatype_string):
     """Parses the given data type string.
 
-    :param datatype_string:
-    :return:
-
     >>> def check_datatype(datatype):
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__())
     ...     python_datatype = _parse_datatype_string(scala_datatype.toString())
@@ -582,9 +555,6 @@ def inferSchema(self, rdd):
 
     def applySchema(self, rdd, schema):
         """Applies the given schema to the given RDD of L{dict}s.
-        :param rdd:
-        :param schema:
-        :return:
 
         >>> schema = StructType([StructField("field1", IntegerType(), False),
         ...     StructField("field2", StringType(), False)])
@@ -594,9 +564,27 @@ def applySchema(self, rdd, schema):
         >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ...     {"field1" : 3, "field2": "row3"}]
         True
+        >>> from datetime import datetime
+        >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0,
+        ...     "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2},
+        ...     "list": [1, 2, 3]}])
+        >>> schema = StructType([
+        ...     StructField("byte", ByteType(), False),
+        ...     StructField("short", ShortType(), False),
+        ...     StructField("float", FloatType(), False),
+        ...     StructField("time", TimestampType(), False),
+        ...     StructField("map", MapType(StringType(), IntegerType(), False), False),
+        ...     StructField("struct", StructType([StructField("b", ShortType(), False)]), False),
+        ...     StructField("list", ArrayType(ByteType(), False), False),
+        ...     StructField("null", DoubleType(), True)])
+        >>> srdd = sqlCtx.applySchema(rdd, schema).map(
+        ...     lambda x: (
+        ...         x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null))
+        >>> srdd.collect()[0]
+        (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
         """
         jrdd = self._pythonToJavaMap(rdd._jrdd)
-        srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema.__repr__())
+        srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__())
         return SchemaRDD(srdd, self)
 
     def registerRDDAsTable(self, rdd, tableName):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index e6eb5a0744d16..ea7120022c51d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -313,13 +313,13 @@ case class StructType(fields: Seq[StructField]) extends DataType {
    */
   lazy val fieldNames: Seq[String] = fields.map(_.name)
   private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
-
+  private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
 
   /**
    * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
    * have a name matching the given name, `null` will be returned.
    */
   def apply(name: String): StructField = {
-    fields.find(f => f.name == name).getOrElse(
+    nameToField.get(name).getOrElse(
       throw new IllegalArgumentException(s"Field ${name} does not exist."))
   }
@@ -333,6 +333,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
       throw new IllegalArgumentException(
         s"Field ${nonExistFields.mkString(",")} does not exist.")
     }
+    // Preserve the original order of fields.
     StructType(fields.filter(f => names.contains(f.name)))
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cc2bf7059ca7a..61aa0882c476a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -125,29 +125,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
     new SchemaRDD(this, logicalPlan)
   }
 
-  /**
-   * Parses the data type in our internal string representation. The data type string should
-   * have the same format as the one generated by `toString` in scala.
-   * It is only used by PySpark.
-   */
-  private[sql] def parseDataType(dataTypeString: String): DataType = {
-    val parser = org.apache.spark.sql.catalyst.types.DataType
-    parser(dataTypeString)
-  }
-
-  /**
-   * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
-   */
-  private[sql] def applySchema(rdd: RDD[Map[String, _]], schemaString: String): SchemaRDD = {
-    val schema = parseDataType(schemaString).asInstanceOf[StructType]
-    val rowRdd = rdd.mapPartitions { iter =>
-      iter.map { map =>
-        new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
-      }
-    }
-    new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd)))
-  }
-
   /**
    * Loads a Parquet file, returning the result as a [[SchemaRDD]].
    *
@@ -438,6 +415,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
     import scala.collection.JavaConversions._
+
     def typeOfComplexValue: PartialFunction[Any, DataType] = {
       case c: java.util.Calendar => TimestampType
       case c: java.util.List[_] =>
@@ -453,48 +431,116 @@ class SQLContext(@transient val sparkContext: SparkContext)
     def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue
 
     val firstRow = rdd.first()
-    val schema = StructType(
-      firstRow.map { case (fieldName, obj) =>
-        StructField(fieldName, typeOfObject(obj), true)
-      }.toSeq)
-
-    def needTransform(obj: Any): Boolean = obj match {
-      case c: java.util.List[_] => true
-      case c: java.util.Map[_, _] => true
-      case c if c.getClass.isArray => true
-      case c: java.util.Calendar => true
-      case c => false
+    val fields = firstRow.map {
+      case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true)
+    }.toSeq
+
+    applySchemaToPythonRDD(rdd, StructType(fields))
+  }
+
+  /**
+   * Parses the data type in our internal string representation. The data type string should
+   * have the same format as the one generated by `toString` in scala.
+   * It is only used by PySpark.
+   */
+  private[sql] def parseDataType(dataTypeString: String): DataType = {
+    val parser = org.apache.spark.sql.catalyst.types.DataType
+    parser(dataTypeString)
+  }
+
+  /**
+   * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
+   */
+  private[sql] def applySchemaToPythonRDD(
+      rdd: RDD[Map[String, _]],
+      schemaString: String): SchemaRDD = {
+    val schema = parseDataType(schemaString).asInstanceOf[StructType]
+    applySchemaToPythonRDD(rdd, schema)
+  }
+
+  /**
+   * Apply the given schema to an RDD. It is only used by PySpark.
+   */
+  private[sql] def applySchemaToPythonRDD(
+      rdd: RDD[Map[String, _]],
+      schema: StructType): SchemaRDD = {
+    import scala.collection.JavaConversions._
+    import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
+
+    def needsConversion(dataType: DataType): Boolean = dataType match {
+      case ByteType => true
+      case ShortType => true
+      case FloatType => true
+      case TimestampType => true
+      case ArrayType(_, _) => true
+      case MapType(_, _, _) => true
+      case StructType(_) => true
+      case other => false
     }
 
-    // convert JList, JArray into Seq, convert JMap into Map
-    // convert Calendar into Timestamp
-    def transform(obj: Any): Any = obj match {
-      case c: java.util.List[_] => c.map(transform).toSeq
-      case c: java.util.Map[_, _] => c.map {
-        case (key, value) => (key, transform(value))
-      }.toMap
-      case c if c.getClass.isArray =>
-        c.asInstanceOf[Array[_]].map(transform).toSeq
-      case c: java.util.Calendar =>
-        new java.sql.Timestamp(c.getTime().getTime())
-      case c => c
+    // Converts value to the type specified by the data type.
+    // Because Python does not have data types for TimestampType, FloatType, ShortType, and
+    // ByteType, we need to explicitly convert values in columns of these data types to the desired
+    // JVM data types.
+    def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+      // TODO: We should check nullable
+      case (null, _) => null
+
+      case (c: java.util.List[_], ArrayType(elementType, _)) =>
+        val converted = c.map { e => convert(e, elementType) }
+        JListWrapper(converted)
+
+      case (c: java.util.Map[_, _], struct: StructType) =>
+        val row = new GenericMutableRow(struct.fields.length)
+        struct.fields.zipWithIndex.foreach {
+          case (field, i) =>
+            val value = convert(c.get(field.name), field.dataType)
+            row.update(i, value)
+        }
+        row
+
+      case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+        val converted = c.map {
+          case (key, value) =>
+            (convert(key, keyType), convert(value, valueType))
+        }
+        JMapWrapper(converted)
+
+      case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+        val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType))
+        converted: Seq[Any]
+
+      case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime())
+      case (c: Int, ByteType) => c.toByte
+      case (c: Int, ShortType) => c.toShort
+      case (c: Double, FloatType) => c.toFloat
+
+      case (c, _) => c
+    }
+
+    val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
+      rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) })
+    } else {
+      rdd
     }
 
-    val need = firstRow.exists { case (key, value) => needTransform(value) }
-    val transformed = if (need) {
-      rdd.mapPartitions { iter =>
-        iter.map {
-          m => m.map {case (key, value) => (key, transform(value))}
+    val rowRdd = convertedRdd.mapPartitions { iter =>
+      val row = new GenericMutableRow(schema.fields.length)
+      val fieldsWithIndex = schema.fields.zipWithIndex
+      iter.map { m =>
+        // We cannot use m.values because the order of values returned by m.values may not
+        // match the order of fields.
+        fieldsWithIndex.foreach {
+          case (field, i) =>
+            val value =
+              m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull
+            row.update(i, value)
         }
-      }
-    } else rdd
-    val rowRdd = transformed.mapPartitions { iter =>
-      iter.map { map =>
-        new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
+        row: Row
       }
     }
+
     new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd)))
   }
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 0940300a72983..2a79abb92d247 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.util.{Map => JMap, List => JList, Set => JSet}
+import java.util.{Map => JMap, List => JList}
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
@@ -380,6 +380,8 @@ class SchemaRDD(
    * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
    */
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+    import scala.collection.Map
+
     def toJava(obj: Any, dataType: DataType): Any = dataType match {
       case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
       case array: ArrayType => obj match {
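For reference, a minimal end-to-end sketch of the Python-side usage this patch enables, adapted from the doctest added to applySchema above. It assumes an existing SparkContext `sc` and SQLContext `sqlCtx` (the same names the doctests rely on) and is illustrative only, not part of the patch.

from datetime import datetime
from pyspark.sql import (StructType, StructField, ByteType, ShortType, FloatType,
                         TimestampType, MapType, ArrayType, StringType, IntegerType,
                         DoubleType)

# An RDD of dicts whose values need per-column conversion on the JVM side
# (int -> Byte/Short, float -> Float, datetime -> java.sql.Timestamp, nested dicts/lists).
rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0,
                       "time": datetime(2010, 1, 1, 1, 1, 1),
                       "map": {"a": 1}, "struct": {"b": 2}, "list": [1, 2, 3]}])

schema = StructType([
    StructField("byte", ByteType(), False),
    StructField("short", ShortType(), False),
    StructField("float", FloatType(), False),
    StructField("time", TimestampType(), False),
    StructField("map", MapType(StringType(), IntegerType(), False), False),
    StructField("struct", StructType([StructField("b", ShortType(), False)]), False),
    StructField("list", ArrayType(ByteType(), False), False),
    StructField("null", DoubleType(), True)])

# applySchema serializes the schema via __repr__ and now routes through
# SQLContext.applySchemaToPythonRDD, which converts each column per the schema
# before building Rows.
srdd = sqlCtx.applySchema(rdd, schema)
print(srdd.collect()[0])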