diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index cb83e89176823..45ffd0756125d 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,8 +20,412 @@ from py4j.protocol import Py4JError
 
-__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+__all__ = [
+    "StringType", "BinaryType", "BooleanType", "DecimalType", "DoubleType",
+    "FloatType", "ByteType", "IntegerType", "LongType", "ShortType",
+    "ArrayType", "MapType", "StructField", "StructType",
+    "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+
+class PrimitiveTypeSingleton(type):
+    _instances = {}
+    def __call__(cls):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
+        return cls._instances[cls]
+
+class StringType(object):
+    """Spark SQL StringType
+
+    The data type representing string values.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "StringType"
+
+class BinaryType(object):
+    """Spark SQL BinaryType
+
+    The data type representing bytes values and bytearray values.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "BinaryType"
+
+class BooleanType(object):
+    """Spark SQL BooleanType
+
+    The data type representing bool values.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "BooleanType"
+
+class TimestampType(object):
+    """Spark SQL TimestampType
+
+    The data type representing datetime.datetime values.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "TimestampType"
+
+class DecimalType(object):
+    """Spark SQL DecimalType
+
+    The data type representing decimal.Decimal values.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "DecimalType"
+
+class DoubleType(object):
+    """Spark SQL DoubleType
+
+    The data type representing float values. Because a Python float is double
+    precision, use L{DoubleType} rather than L{FloatType} for Python float values.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "DoubleType"
+
+class FloatType(object):
+    """Spark SQL FloatType
+
+    For PySpark, please use L{DoubleType} instead of L{FloatType}.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "FloatType"
+
+class ByteType(object):
+    """Spark SQL ByteType
+
+    For PySpark, please use L{IntegerType} instead of L{ByteType}.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "ByteType"
+
+class IntegerType(object):
+    """Spark SQL IntegerType
+
+    The data type representing int values.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "IntegerType"
+
+class LongType(object):
+    """Spark SQL LongType
+
+    The data type representing long values. If any value is beyond the range of
+    [-9223372036854775808, 9223372036854775807], please use L{DecimalType}.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "LongType"
+
+class ShortType(object):
+    """Spark SQL ShortType
+
+    For PySpark, please use L{IntegerType} instead of L{ShortType}.
+
+    """
+    __metaclass__ = PrimitiveTypeSingleton
+
+    def _get_scala_type_string(self):
+        return "ShortType"
+
+class ArrayType(object):
+    """Spark SQL ArrayType
+
+    The data type representing list values.
+
+    """
+    def __init__(self, elementType, containsNull):
+        """Creates an ArrayType
+
+        :param elementType: the data type of elements.
+        :param containsNull: indicates whether the list contains null values.
+
+        >>> ArrayType(StringType(), True) == ArrayType(StringType(), False)
+        False
+        >>> ArrayType(StringType(), True) == ArrayType(StringType(), True)
+        True
+        """
+        self.elementType = elementType
+        self.containsNull = containsNull
+
+    def _get_scala_type_string(self):
+        return "ArrayType(" + self.elementType._get_scala_type_string() + "," + \
+            str(self.containsNull).lower() + ")"
+
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                self.elementType == other.elementType and
+                self.containsNull == other.containsNull)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+class MapType(object):
+    """Spark SQL MapType
+
+    The data type representing dict values.
+
+    """
+    def __init__(self, keyType, valueType):
+        """Creates a MapType
+
+        :param keyType: the data type of keys.
+        :param valueType: the data type of values.
+
+        >>> MapType(StringType(), IntegerType()) == MapType(StringType(), IntegerType())
+        True
+        >>> MapType(StringType(), IntegerType()) == MapType(StringType(), FloatType())
+        False
+        """
+        self.keyType = keyType
+        self.valueType = valueType
+
+    def _get_scala_type_string(self):
+        return "MapType(" + self.keyType._get_scala_type_string() + "," + \
+            self.valueType._get_scala_type_string() + ")"
+
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                self.keyType == other.keyType and
+                self.valueType == other.valueType)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+class StructField(object):
+    """Spark SQL StructField
+
+    Represents a field in a StructType.
+
+    """
+    def __init__(self, name, dataType, nullable):
+        """Creates a StructField
+
+        :param name: the name of this field.
+        :param dataType: the data type of this field.
+        :param nullable: indicates whether values of this field can be null.
+
+        >>> StructField("f1", StringType(), True) == StructField("f1", StringType(), True)
+        True
+        >>> StructField("f1", StringType(), True) == StructField("f2", StringType(), True)
+        False
+        """
+        self.name = name
+        self.dataType = dataType
+        self.nullable = nullable
+
+    def _get_scala_type_string(self):
+        return "StructField(" + self.name + "," + \
+            self.dataType._get_scala_type_string() + "," + \
+            str(self.nullable).lower() + ")"
+
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                self.name == other.name and
+                self.dataType == other.dataType and
+                self.nullable == other.nullable)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+class StructType(object):
+    """Spark SQL StructType
+
+    The data type representing tuple values.
+
+    """
+    def __init__(self, fields):
+        """Creates a StructType
+
+        :param fields: the list of L{StructField}s for this struct.
+
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
+        >>> struct1 == struct2
+        True
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True),
+        ...     StructField("f2", IntegerType(), False)])
[StructField("f2", IntegerType, False)]]) + >>> struct1 == struct2 + False + """ + self.fields = fields + + def _get_scala_type_string(self): + return "StructType(List(" + \ + ",".join([field._get_scala_type_string() for field in self.fields]) + "))" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and \ + self.fields == other.fields) + + def __ne__(self, other): + return not self.__eq__(other) + +def _parse_datatype_list(datatype_list_string): + index = 0 + datatype_list = [] + start = 0 + depth = 0 + while index < len(datatype_list_string): + if depth == 0 and datatype_list_string[index] == ",": + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + start = index + 1 + elif datatype_list_string[index] == "(": + depth += 1 + elif datatype_list_string[index] == ")": + depth -= 1 + + index += 1 + + # Handle the last data type + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + return datatype_list + +def _parse_datatype_string(datatype_string): + """Parses the given data type string. + :param datatype_string: + :return: + >>> def check_datatype(datatype): + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype._get_scala_type_string()) + ... python_datatype = _parse_datatype_string(scala_datatype.toString()) + ... return datatype == python_datatype + >>> check_datatype(StringType()) + True + >>> check_datatype(BinaryType()) + True + >>> check_datatype(BooleanType()) + True + >>> check_datatype(TimestampType()) + True + >>> check_datatype(DecimalType()) + True + >>> check_datatype(DoubleType()) + True + >>> check_datatype(FloatType()) + True + >>> check_datatype(ByteType()) + True + >>> check_datatype(IntegerType()) + True + >>> check_datatype(LongType()) + True + >>> check_datatype(ShortType()) + True + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + True + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + True + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + True + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False)]) + >>> check_datatype(complex_structtype) + True + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + True + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, complex_arraytype) + >>> check_datatype(complex_maptype) + True + """ + left_bracket_index = datatype_string.find("(") + if left_bracket_index == -1: + # It is a primitive type. 
+        left_bracket_index = len(datatype_string)
+    type_or_field = datatype_string[:left_bracket_index]
+    rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip()
+    if type_or_field == "StringType":
+        return StringType()
+    elif type_or_field == "BinaryType":
+        return BinaryType()
+    elif type_or_field == "BooleanType":
+        return BooleanType()
+    elif type_or_field == "TimestampType":
+        return TimestampType()
+    elif type_or_field == "DecimalType":
+        return DecimalType()
+    elif type_or_field == "DoubleType":
+        return DoubleType()
+    elif type_or_field == "FloatType":
+        return FloatType()
+    elif type_or_field == "ByteType":
+        return ByteType()
+    elif type_or_field == "IntegerType":
+        return IntegerType()
+    elif type_or_field == "LongType":
+        return LongType()
+    elif type_or_field == "ShortType":
+        return ShortType()
+    elif type_or_field == "ArrayType":
+        last_comma_index = rest_part.rfind(",")
+        containsNull = True
+        if rest_part[last_comma_index+1:].strip().lower() == "false":
+            containsNull = False
+        elementType = _parse_datatype_string(rest_part[:last_comma_index].strip())
+        return ArrayType(elementType, containsNull)
+    elif type_or_field == "MapType":
+        keyType, valueType = _parse_datatype_list(rest_part.strip())
+        return MapType(keyType, valueType)
+    elif type_or_field == "StructField":
+        first_comma_index = rest_part.find(",")
+        name = rest_part[:first_comma_index].strip()
+        last_comma_index = rest_part.rfind(",")
+        nullable = True
+        if rest_part[last_comma_index+1:].strip().lower() == "false":
+            nullable = False
+        dataType = _parse_datatype_string(
+            rest_part[first_comma_index+1:last_comma_index].strip())
+        return StructField(name, dataType, nullable)
+    elif type_or_field == "StructType":
+        # rest_part should be in the format like
+        # List(StructField(field1,IntegerType,false)).
+        field_list_string = rest_part[rest_part.find("(")+1:-1]
+        fields = _parse_datatype_list(field_list_string)
+        return StructType(fields)
+
 
 class SQLContext:
     """Main entry point for SparkSQL functionality.
@@ -107,6 +511,24 @@ def inferSchema(self, rdd):
         srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
         return SchemaRDD(srdd, self)
 
+    def applySchema(self, rdd, schema):
+        """Applies the given schema to the given RDD of L{dict}s.
+
+        :param rdd: an RDD of L{dict}s.
+        :param schema: a L{StructType} describing the schema of the rows.
+
+        >>> schema = StructType([StructField("field1", IntegerType(), False),
+        ...     StructField("field2", StringType(), False)])
+        >>> srdd = sqlCtx.applySchema(rdd, schema)
+        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+        >>> srdd2 = sqlCtx.sql("SELECT * from table1")
+        >>> srdd2.collect() == [{"field1": 1, "field2": "row1"}, {"field1": 2, "field2": "row2"},
+        ...     {"field1": 3, "field2": "row3"}]
+        True
+        """
+        jrdd = self._pythonToJavaMap(rdd._jrdd)
+        srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema._get_scala_type_string())
+        return SchemaRDD(srdd, self)
+
     def registerRDDAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.
 
@@ -137,10 +559,11 @@ def parquetFile(self, path):
         jschema_rdd = self._ssql_ctx.parquetFile(path)
         return SchemaRDD(jschema_rdd, self)
 
-    def jsonFile(self, path):
-        """Loads a text file storing one JSON object per line,
-        returning the result as a L{SchemaRDD}.
-        It goes through the entire dataset once to determine the schema.
+    def jsonFile(self, path, schema=None):
+        """Loads a text file storing one JSON object per line as a L{SchemaRDD}.
+
+        If the schema is provided, applies the given schema to this JSON dataset.
+        Otherwise, it goes through the entire dataset once to determine the schema.
 
         >>> import tempfile, shutil
         >>> jsonFile = tempfile.mkdtemp()
@@ -149,8 +572,8 @@ def jsonFile(self, path):
         >>> for json in jsonStrings:
         ...     print>>ofn, json
         >>> ofn.close()
-        >>> srdd = sqlCtx.jsonFile(jsonFile)
-        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+        >>> srdd1 = sqlCtx.jsonFile(jsonFile)
+        >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
         >>> srdd2 = sqlCtx.sql(
         ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
         >>> srdd2.collect() == [
@@ -158,16 +581,45 @@ def jsonFile(self, path):
         ...     {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
         ...     {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
         True
+        >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
+        >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+        >>> srdd4 = sqlCtx.sql(
+        ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+        >>> srdd4.collect() == [
+        ...     {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+        ...     {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+        ...     {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
+        True
+        >>> schema = StructType([
+        ...     StructField("field2", StringType(), True),
+        ...     StructField("field3",
+        ...         StructType([
+        ...             StructField("field5", ArrayType(IntegerType(), False), True)]), False)])
+        >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
+        >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
+        >>> srdd6 = sqlCtx.sql(
+        ...     "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3")
+        >>> srdd6.collect() == [
+        ...     {"f1": "row1", "f2": None, "f3": None},
+        ...     {"f1": None, "f2": [10, 11], "f3": 10},
+        ...     {"f1": "row3", "f2": [], "f3": None}]
+        True
         """
-        jschema_rdd = self._ssql_ctx.jsonFile(path)
+        if schema is None:
+            jschema_rdd = self._ssql_ctx.jsonFile(path)
+        else:
+            scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string())
+            jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
         return SchemaRDD(jschema_rdd, self)
 
-    def jsonRDD(self, rdd):
-        """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
-        It goes through the entire dataset once to determine the schema.
+    def jsonRDD(self, rdd, schema=None):
+        """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
 
-        >>> srdd = sqlCtx.jsonRDD(json)
-        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+        If the schema is provided, applies the given schema to this JSON dataset.
+        Otherwise, it goes through the entire dataset once to determine the schema.
+
+        >>> srdd1 = sqlCtx.jsonRDD(json)
+        >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
         >>> srdd2 = sqlCtx.sql(
         ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
         >>> srdd2.collect() == [
@@ -175,6 +627,29 @@ def jsonRDD(self, rdd):
         ...     {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
         ...     {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
         True
+        >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
+        >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+        >>> srdd4 = sqlCtx.sql(
+        ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+        >>> srdd4.collect() == [
+        ...     {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+        ...     {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+        ...     {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
{"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] + True + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + >>> srdd6.collect() == [ + ... {"f1": "row1", "f2": None, "f3": None}, + ... {"f1": None, "f2": [10, 11], "f3": 10}, + ... {"f1": "row3", "f2": [], "f3": None}] + True """ def func(split, iterator): for x in iterator: @@ -184,7 +659,11 @@ def func(split, iterator): keyed = PipelinedRDD(rdd, func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string()) + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) def sql(self, sqlQuery): @@ -387,6 +866,9 @@ def saveAsTable(self, tableName): """Creates a new table with the contents of this SchemaRDD.""" self._jschema_rdd.saveAsTable(tableName) + def schema(self): + return _parse_datatype_string(self._jschema_rdd.schema().toString()) + def schemaString(self): """Returns the output schema in the tree format.""" return self._jschema_rdd.schemaString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 11880a80443f3..e358f00f8d852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -105,6 +105,28 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, logicalPlan) } + /** + * Parses the data type in our internal string representation. The data type string should + * have the same format as the one generate by `toString` in scala. + */ + private[sql] def parseDataType(dataTypeString: String): DataType = { + val parser = org.apache.spark.sql.catalyst.types.DataType + parser(dataTypeString) + } + + /** + * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. + */ + private[sql] def applySchema(rdd: RDD[Map[String, _]], schemaString: String): SchemaRDD = { + val schema = parseDataType(schemaString).asInstanceOf[StructType] + val rowRdd = rdd.mapPartitions { iter => + iter.map { map => + new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row + } + } + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))) + } + /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. *