[SPARK-3594] [PySpark] [SQL] take more rows to infer schema or sampling #2716

Closed · wants to merge 10 commits
196 changes: 128 additions & 68 deletions python/pyspark/sql.py
@@ -109,6 +109,15 @@ def __eq__(self, other):
return self is other


class NullType(PrimitiveType):

"""Spark SQL NullType

The data type representing None, used for the types which has not
been inferred.
"""


class StringType(PrimitiveType):

"""Spark SQL StringType
@@ -331,7 +340,7 @@ class StructField(DataType):

"""

def __init__(self, name, dataType, nullable, metadata=None):
def __init__(self, name, dataType, nullable=True, metadata=None):
"""Creates a StructField
:param name: the name of this field.
:param dataType: the data type of this field.
@@ -484,6 +493,7 @@ def _parse_datatype_json_value(json_value):

# Mapping Python types to Spark SQL DataType
_type_mappings = {
type(None): NullType,
bool: BooleanType,
int: IntegerType,
long: LongType,
@@ -500,22 +510,22 @@ def _parse_datatype_json_value(json_value):

def _infer_type(obj):
"""Infer the DataType from obj"""
if obj is None:
raise ValueError("Can not infer type for None")

dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()

if isinstance(obj, dict):
if not obj:
raise ValueError("Can not infer type for empty dict")
key, value = obj.iteritems().next()
return MapType(_infer_type(key), _infer_type(value), True)
for key, value in obj.iteritems():
if key is not None and value is not None:
return MapType(_infer_type(key), _infer_type(value), True)
else:
return MapType(NullType(), NullType(), True)
elif isinstance(obj, (list, array)):
if not obj:
raise ValueError("Can not infer type for empty list/array")
return ArrayType(_infer_type(obj[0]), True)
for v in obj:
if v is not None:
return ArrayType(_infer_type(obj[0]), True)
else:
return ArrayType(NullType(), True)
else:
try:
return _infer_schema(obj)
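
Not part of this diff: a minimal sketch, in the same Python 2 style as this file, of how the private helper now behaves. Values that previously raised ValueError get a NullType placeholder instead (import shown only for illustration):

from pyspark.sql import _infer_type  # private helper, used here purely as an example

_infer_type(None)    # NullType() instead of ValueError
_infer_type([])      # ArrayType(NullType(), True) instead of ValueError
_infer_type({})      # MapType(NullType(), NullType(), True) instead of ValueError
_infer_type([1, 2])  # ArrayType(IntegerType(), True), unchanged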
@@ -548,60 +558,93 @@ def _infer_schema(row):
return StructType(fields)


def _create_converter(obj, dataType):
def _has_nulltype(dt):
""" Return whether there is NullType in `dt` or not """
if isinstance(dt, StructType):
return any(_has_nulltype(f.dataType) for f in dt.fields)
elif isinstance(dt, ArrayType):
return _has_nulltype((dt.elementType))
elif isinstance(dt, MapType):
return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
else:
return isinstance(dt, NullType)


def _merge_type(a, b):
if isinstance(a, NullType):
return b
elif isinstance(b, NullType):
return a
elif type(a) is not type(b):
# TODO: type cast (such as int -> long)
raise TypeError("Can not merge type %s and %s" % (a, b))

# same type
if isinstance(a, StructType):
nfs = dict((f.name, f.dataType) for f in b.fields)
fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
for f in a.fields]
names = set([f.name for f in fields])
for n in nfs:
if n not in names:
fields.append(StructField(n, nfs[n]))
return StructType(fields)

elif isinstance(a, ArrayType):
return ArrayType(_merge_type(a.elementType, b.elementType), True)

elif isinstance(a, MapType):
return MapType(_merge_type(a.keyType, b.keyType),
_merge_type(a.valueType, b.valueType),
True)
else:
return a
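
Not part of this diff: a short illustration, assuming the helpers defined in this file, of how a schema containing NullType gets resolved by merging in the schema of another row:

a = _infer_schema({"x": 1, "y": None})  # y is inferred as NullType
b = _infer_schema({"x": 2, "y": "s"})   # y is inferred as StringType
_has_nulltype(a)                        # True, the schema is still incomplete
merged = _merge_type(a, b)              # the NullType hole is filled with StringType
_has_nulltype(merged)                   # False, every field type is now known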


def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
if isinstance(dataType, ArrayType):
conv = _create_converter(obj[0], dataType.elementType)
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)

elif isinstance(dataType, MapType):
value = obj.values()[0]
conv = _create_converter(value, dataType.valueType)
conv = _create_converter(dataType.valueType)
return lambda row: dict((k, conv(v)) for k, v in row.iteritems())

elif isinstance(dataType, NullType):
return lambda x: None

elif not isinstance(dataType, StructType):
return lambda x: x

# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]

def convert_struct(obj):
if obj is None:
return

if isinstance(obj, tuple):
if hasattr(obj, "fields"):
d = dict(zip(obj.fields, obj))
if hasattr(obj, "__FIELDS__"):
d = dict(zip(obj.__FIELDS__, obj))
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
d = dict(obj)
else:
raise ValueError("unexpected tuple: %s" % obj)

if isinstance(obj, dict):
conv = lambda o: tuple(o.get(n) for n in names)

elif isinstance(obj, tuple):
if hasattr(obj, "_fields"): # namedtuple
conv = tuple
elif hasattr(obj, "__FIELDS__"):
conv = tuple
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
conv = lambda o: tuple(v for k, v in o)
elif isinstance(obj, dict):
d = obj
elif hasattr(obj, "__dict__"): # object
d = obj.__dict__
else:
raise ValueError("unexpected tuple")
raise ValueError("Unexpected obj: %s" % obj)

elif hasattr(obj, "__dict__"): # object
conv = lambda o: [o.__dict__.get(n, None) for n in names]
return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])

if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
return conv

row = conv(obj)
convs = [_create_converter(v, f.dataType)
for v, f in zip(row, dataType.fields)]

def nested_conv(row):
return tuple(f(v) for f, v in zip(convs, conv(row)))

return nested_conv


def _drop_schema(rows, schema):
""" all the names of fields, becoming tuples"""
iterator = iter(rows)
row = iterator.next()
converter = _create_converter(row, schema)
yield converter(row)
for i in iterator:
yield converter(i)
return convert_struct
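
Not part of this diff: a quick sketch, assuming the helpers above, of what the rewritten converter produces. It drops field names and emits plain tuples ordered by the schema, ready for applySchema:

schema = _infer_schema({"a": 1, "b": "s"})
conv = _create_converter(schema)
conv({"a": 1, "b": "s"})  # (1, 's'), names dropped, order taken from the schema
conv(None)                # None, missing rows pass through unchanged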


_BRACKETS = {'(': ')', '[': ']', '{': '}'}
@@ -713,7 +756,7 @@ def _infer_schema_type(obj, dataType):
return _infer_type(obj)

if not obj:
raise ValueError("Can not infer type from empty value")
return NullType()

if isinstance(dataType, ArrayType):
eType = _infer_schema_type(obj[0], dataType.elementType)
@@ -1049,18 +1092,20 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc._javaAccumulator,
returnType.json())

def inferSchema(self, rdd):
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.

We peek at the first row of the RDD to determine the fields' names
and types. Nested collections are supported, which include array,
dict, list, Row, tuple, namedtuple, or object.
When samplingRatio is specified, the schema is inferred by looking
at the types of each row in the sampled dataset. Otherwise, the
first 100 rows of the RDD are inspected. Nested collections are
supported, which can include array, dict, list, Row, tuple,
namedtuple, or object.

All the rows in `rdd` should have the same type with the first one,
or it will cause runtime exceptions.
Each row could be L{pyspark.sql.Row} object or namedtuple or objects.
Using top level dicts is deprecated, as dict is used to represent Maps.

Each row could be L{pyspark.sql.Row} object or namedtuple or objects,
using dict is deprecated.
If a single column has multiple distinct inferred types, it may cause
runtime exceptions.
Contributor comment (suggested docstring wording):
When samplingRatio is specified, the schema is inferred by looking at the types of each row in the sampled dataset. Otherwise, the first 100 rows of the RDD are inspected. Nested collections are supported, which can include array, dict, list, Row, tuple, namedtuple, or object.

Each row could be L{pyspark.sql.Row} object or namedtuple or objects. Using top level dicts is deprecated, as this datatype is used to represent Maps.

If a single column has multiple distinct inferred types, it may cause runtime exceptions.


>>> rdd = sc.parallelize(
... [Row(field1=1, field2="row1"),
@@ -1097,8 +1142,23 @@ def inferSchema(self, rdd):
warnings.warn("Using RDD of dict to inferSchema is deprecated,"
"please use pyspark.sql.Row instead")

schema = _infer_schema(first)
rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
if samplingRatio is None:
schema = _infer_schema(first)
if _has_nulltype(schema):
for row in rdd.take(100)[1:]:
schema = _merge_type(schema, _infer_schema(row))
if not _has_nulltype(schema):
break
else:
warnings.warn("Some of types cannot be determined by the "
"first 100 rows, please try again with sampling")
else:
if samplingRatio < 0.99:
rdd = rdd.sample(False, float(samplingRatio))
schema = rdd.map(_infer_schema).reduce(_merge_type)

converter = _create_converter(schema)
rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
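
Not part of this diff: a hedged usage sketch of the two code paths above, assuming a SparkContext sc, a SQLContext sqlCtx and Row imported from pyspark.sql, as in the doctests:

# The first row leaves a NullType hole, so up to 100 rows are scanned to fill it
rdd = sc.parallelize([Row(a=None, b=1), Row(a="x", b=2)] * 10)
srdd = sqlCtx.inferSchema(rdd)  # column a is resolved to StringType from later rows

# With samplingRatio, every row in the sample contributes to the merged schema
srdd2 = sqlCtx.inferSchema(rdd, samplingRatio=0.5)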

def applySchema(self, rdd, schema):
@@ -1219,16 +1279,16 @@ def parquetFile(self, path):
jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
return SchemaRDD(jschema_rdd, self)

def jsonFile(self, path, schema=None):
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
L{SchemaRDD}.

If the schema is provided, applies the given schema to this
JSON dataset.

Otherwise, it goes through the entire dataset once to determine
the schema.
Otherwise, it samples the dataset with ratio `samplingRatio` to
determine the schema.

>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -1274,20 +1334,20 @@ def jsonFile(self, path, schema=None):
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
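
Not part of this diff: a brief usage sketch of the new samplingRatio parameter, with a hypothetical path:

# Infer the schema of a large JSON dataset from roughly 10% of its lines
srdd = sqlCtx.jsonFile("hdfs:///data/people.json", samplingRatio=0.1)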

def jsonRDD(self, rdd, schema=None):
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.

If the schema is provided, applies the given schema to this
JSON dataset.

Otherwise, it goes through the entire dataset once to determine
the schema.
Otherwise, it samples the dataset with ratio `samplingRatio` to
determine the schema.

>>> srdd1 = sqlCtx.jsonRDD(json)
>>> sqlCtx.registerRDDAsTable(srdd1, "table1")
@@ -1344,7 +1404,7 @@ def func(iterator):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
19 changes: 19 additions & 0 deletions python/pyspark/tests.py
@@ -781,6 +781,25 @@ def test_serialize_nested_array_and_map(self):
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)

def test_infer_schema(self):
d = [Row(l=[], d={}),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
srdd = self.sqlCtx.inferSchema(rdd)
self.assertEqual([], srdd.map(lambda r: r.l).first())
self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
srdd.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
self.assertEqual(1, result.first()[0])

srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
self.assertEqual(srdd.schema(), srdd2.schema())
self.assertEqual({}, srdd2.map(lambda r: r.d).first())
self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
srdd2.registerTempTable("test2")
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.first()[0])

def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
@@ -213,7 +213,7 @@ trait PrimitiveType extends DataType {
}

object PrimitiveType {
private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all
private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap

/** Given the string representation of a type, return its DataType */