[SPARK-22566][PYTHON] Better error message for _merge_type in Pandas to Spark DF conversion

## What changes were proposed in this pull request?

This PR provides a better error message when `spark_session.createDataFrame(pandas_df)` is called without a schema and schema inference fails because a column contains incompatible types.

The Pandas column names are now propagated into type inference, so the error message names the column on which the merge failed.

https://issues.apache.org/jira/browse/SPARK-22566
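
For illustration, a minimal sketch (not part of the patch; the DataFrame contents are made up) of the failure mode this change improves. Column `b` mixes an int and a str, so inference cannot merge `LongType` with `StringType`, and the `TypeError` now names the column:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Column "b" holds both an int and a str, so schema inference cannot
# merge LongType with StringType.
pdf = pd.DataFrame({"a": [1, 2], "b": [3, "x"]})

try:
    spark.createDataFrame(pdf)  # no schema given, so types are inferred
except TypeError as e:
    print(e)  # with this patch the message points at "field b"
```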

## How was this patch tested?

Manually in the `./bin/pyspark` console, and with new tests: `./python/run-tests`

<img width="873" alt="screen shot 2017-11-21 at 13 29 49" src="https://user-images.githubusercontent.com/3977115/33080121-382274e0-cecf-11e7-808f-057a65bb7b00.png">

I state that the contribution is my original work and that I license the work to the Apache Spark project under the project’s open source license.

Author: Guilherme Berger <[email protected]>

Closes #19792 from gberger/master.
gberger-palantir authored and ueshin committed Jan 8, 2018
1 parent 71d65a3 commit 3e40eb3
Showing 3 changed files with 129 additions and 16 deletions.
17 changes: 9 additions & 8 deletions python/pyspark/sql/session.py
```diff
@@ -325,11 +325,12 @@ def range(self, start, end=None, step=1, numPartitions=None):
 
         return DataFrame(jdf, self._wrapped)
 
-    def _inferSchemaFromList(self, data):
+    def _inferSchemaFromList(self, data, names=None):
         """
         Infer schema from list of Row or tuple.
 
         :param data: list of Row or tuple
+        :param names: list of column names
         :return: :class:`pyspark.sql.types.StructType`
         """
         if not data:
@@ -338,12 +339,12 @@ def _inferSchemaFromList(self, data):
         if type(first) is dict:
             warnings.warn("inferring schema from dict is deprecated,"
                           "please use pyspark.sql.Row instead")
-        schema = reduce(_merge_type, map(_infer_schema, data))
+        schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
         return schema
 
-    def _inferSchema(self, rdd, samplingRatio=None):
+    def _inferSchema(self, rdd, samplingRatio=None, names=None):
         """
         Infer schema from an RDD of Row or tuple.
 
@@ -360,10 +361,10 @@ def _inferSchema(self, rdd, samplingRatio=None):
                 "Use pyspark.sql.Row instead")
 
         if samplingRatio is None:
-            schema = _infer_schema(first)
+            schema = _infer_schema(first, names=names)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row))
+                    schema = _merge_type(schema, _infer_schema(row, names=names))
                     if not _has_nulltype(schema):
                         break
                 else:
@@ -372,15 +373,15 @@ def _inferSchema(self, rdd, samplingRatio=None):
         else:
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(_infer_schema).reduce(_merge_type)
+            schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
         return schema
 
     def _createFromRDD(self, rdd, schema, samplingRatio):
         """
         Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
         """
         if schema is None or isinstance(schema, (list, tuple)):
-            struct = self._inferSchema(rdd, samplingRatio)
+            struct = self._inferSchema(rdd, samplingRatio, names=schema)
             converter = _create_converter(struct)
             rdd = rdd.map(converter)
             if isinstance(schema, (list, tuple)):
@@ -406,7 +407,7 @@ def _createFromLocal(self, data, schema):
             data = list(data)
 
         if schema is None or isinstance(schema, (list, tuple)):
-            struct = self._inferSchemaFromList(data)
+            struct = self._inferSchemaFromList(data, names=schema)
             converter = _create_converter(struct)
             data = map(converter, data)
             if isinstance(schema, (list, tuple)):
```
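
With these changes, a plain list of column names passed as `schema` is threaded through inference on both the local and RDD paths. A short sketch of the two observable effects (it assumes a running `SparkSession` bound to `spark`; the same behaviour is asserted by the new tests below):

```python
# Too few names: the remaining columns fall back to positional _N names.
df = spark.createDataFrame([["a", "b"]], ["col1"])
print(df.columns)  # ['col1', '_2']

# Incompatible types: the inference error now names the offending field.
rdd = spark.sparkContext.parallelize([[1, 1], ["x", 1]])
try:
    spark.createDataFrame(rdd, schema=["a", "b"], samplingRatio=0.99)
except TypeError as e:
    print(e)  # e.g. "field a: Can not merge type ... and ..."
```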
100 changes: 100 additions & 0 deletions python/pyspark/sql/tests.py
```diff
@@ -68,6 +68,7 @@
 from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
 from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
 from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
+from pyspark.sql.types import _merge_type
 from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
@@ -898,6 +899,15 @@ def test_infer_schema(self):
         result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
+    def test_infer_schema_not_enough_names(self):
+        df = self.spark.createDataFrame([["a", "b"]], ["col1"])
+        self.assertEqual(df.columns, ['col1', '_2'])
+
+    def test_infer_schema_fails(self):
+        with self.assertRaisesRegexp(TypeError, 'field a'):
+            self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
+                                       schema=["a", "b"], samplingRatio=0.99)
+
     def test_infer_nested_schema(self):
         NestedRow = Row("f1", "f2")
         nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
@@ -918,6 +928,10 @@ def test_infer_nested_schema(self):
         df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
 
+    def test_create_dataframe_from_dict_respects_schema(self):
+        df = self.spark.createDataFrame([{'a': 1}], ["b"])
+        self.assertEqual(df.columns, ['b'])
+
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
         df = self.spark.createDataFrame(data)
@@ -1772,6 +1786,92 @@ def test_infer_long_type(self):
         self.assertEqual(_infer_type(2**61), LongType())
         self.assertEqual(_infer_type(2**71), LongType())
 
+    def test_merge_type(self):
+        self.assertEqual(_merge_type(LongType(), NullType()), LongType())
+        self.assertEqual(_merge_type(NullType(), LongType()), LongType())
+
+        self.assertEqual(_merge_type(LongType(), LongType()), LongType())
+
+        self.assertEqual(_merge_type(
+            ArrayType(LongType()),
+            ArrayType(LongType())
+        ), ArrayType(LongType()))
+        with self.assertRaisesRegexp(TypeError, 'element in array'):
+            _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
+
+        self.assertEqual(_merge_type(
+            MapType(StringType(), LongType()),
+            MapType(StringType(), LongType())
+        ), MapType(StringType(), LongType()))
+        with self.assertRaisesRegexp(TypeError, 'key of map'):
+            _merge_type(
+                MapType(StringType(), LongType()),
+                MapType(DoubleType(), LongType()))
+        with self.assertRaisesRegexp(TypeError, 'value of map'):
+            _merge_type(
+                MapType(StringType(), LongType()),
+                MapType(StringType(), DoubleType()))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
+            StructType([StructField("f1", LongType()), StructField("f2", StringType())])
+        ), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'field f1'):
+            _merge_type(
+                StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
+                StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
+            StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
+        ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
+        with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
+            _merge_type(
+                StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
+                StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
+            StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
+        ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
+            _merge_type(
+                StructType([
+                    StructField("f1", ArrayType(LongType())),
+                    StructField("f2", StringType())]),
+                StructType([
+                    StructField("f1", ArrayType(DoubleType())),
+                    StructField("f2", StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([
+                StructField("f1", MapType(StringType(), LongType())),
+                StructField("f2", StringType())]),
+            StructType([
+                StructField("f1", MapType(StringType(), LongType())),
+                StructField("f2", StringType())])
+        ), StructType([
+            StructField("f1", MapType(StringType(), LongType())),
+            StructField("f2", StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
+            _merge_type(
+                StructType([
+                    StructField("f1", MapType(StringType(), LongType())),
+                    StructField("f2", StringType())]),
+                StructType([
+                    StructField("f1", MapType(StringType(), DoubleType())),
+                    StructField("f2", StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
+            StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
+        ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
+        with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
+            _merge_type(
+                StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
+                StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
+            )
+
     def test_filter_with_datetime(self):
         time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
         date = time.date()
```
28 changes: 20 additions & 8 deletions python/pyspark/sql/types.py
```diff
@@ -1073,7 +1073,7 @@ def _infer_type(obj):
         raise TypeError("not supported type: %s" % type(obj))
 
 
-def _infer_schema(row):
+def _infer_schema(row, names=None):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1084,7 +1084,10 @@ def _infer_schema(row):
         elif hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
         else:
-            names = ['_%d' % i for i in range(1, len(row) + 1)]
+            if names is None:
+                names = ['_%d' % i for i in range(1, len(row) + 1)]
+            elif len(names) < len(row):
+                names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
             items = zip(names, row)
 
     elif hasattr(row, "__dict__"):  # object
@@ -1109,19 +1112,27 @@ def _has_nulltype(dt):
     return isinstance(dt, NullType)
 
 
-def _merge_type(a, b):
+def _merge_type(a, b, name=None):
+    if name is None:
+        new_msg = lambda msg: msg
+        new_name = lambda n: "field %s" % n
+    else:
+        new_msg = lambda msg: "%s: %s" % (name, msg)
+        new_name = lambda n: "field %s in %s" % (n, name)
+
     if isinstance(a, NullType):
         return b
     elif isinstance(b, NullType):
         return a
     elif type(a) is not type(b):
         # TODO: type cast (such as int -> long)
-        raise TypeError("Can not merge type %s and %s" % (type(a), type(b)))
+        raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
 
     # same type
     if isinstance(a, StructType):
         nfs = dict((f.name, f.dataType) for f in b.fields)
-        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()),
+                                                  name=new_name(f.name)))
                   for f in a.fields]
         names = set([f.name for f in fields])
         for n in nfs:
@@ -1130,11 +1141,12 @@ def _merge_type(a, b):
         return StructType(fields)
 
     elif isinstance(a, ArrayType):
-        return ArrayType(_merge_type(a.elementType, b.elementType), True)
+        return ArrayType(_merge_type(a.elementType, b.elementType,
+                                     name='element in array %s' % name), True)
 
     elif isinstance(a, MapType):
-        return MapType(_merge_type(a.keyType, b.keyType),
-                       _merge_type(a.valueType, b.valueType),
+        return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name),
+                       _merge_type(a.valueType, b.valueType, name='value of map %s' % name),
                        True)
     else:
         return a
```
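
Because `name` is rebuilt at each level of recursion, the prefixes compose for nested types. A quick sketch of calling the (private) helper directly, matching the last assertion in `test_merge_type` above:

```python
from pyspark.sql.types import (ArrayType, DoubleType, LongType, MapType,
                               StringType, StructField, StructType, _merge_type)

a = StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
b = StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])

try:
    _merge_type(a, b)
except TypeError as e:
    # "key of map element in array field f1: Can not merge type ... and ..."
    print(e)
```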
