From 518fdd4f3d0e968cef2e3ba1b0220daee5ee7778 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Tue, 21 Nov 2017 15:06:25 +0000 Subject: [PATCH 01/15] [SPARK-22566][PYTHON] Better error message for `_merge_type` in Pandas to Spark DF conversion --- python/pyspark/sql/session.py | 7 ++++--- python/pyspark/sql/types.py | 14 +++++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 47c58bb28221c..9516594a461df 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -324,11 +324,12 @@ def range(self, start, end=None, step=1, numPartitions=None): return DataFrame(jdf, self._wrapped) - def _inferSchemaFromList(self, data): + def _inferSchemaFromList(self, data, names=None): """ Infer schema from list of Row or tuple. :param data: list of Row or tuple + :param names: list of column names :return: :class:`pyspark.sql.types.StructType` """ if not data: @@ -337,7 +338,7 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, map(_infer_schema, data)) + schema = reduce(_merge_type, [_infer_schema(row, names) for row in data]) if _has_nulltype(schema): raise ValueError("Some of types cannot be determined after inferring") return schema @@ -405,7 +406,7 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data) + struct = self._inferSchemaFromList(data, names=schema) converter = _create_converter(struct) data = map(converter, data) if isinstance(schema, (list, tuple)): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fe62f60dd6d0e..3e34237513dd1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1072,7 +1072,7 @@ def _infer_type(obj): raise TypeError("not supported type: %s" % type(obj)) -def _infer_schema(row): +def _infer_schema(row, names=None): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) @@ -1083,7 +1083,8 @@ def _infer_schema(row): elif hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) else: - names = ['_%d' % i for i in range(1, len(row) + 1)] + if names is None: + names = ['_%d' % i for i in range(1, len(row) + 1)] items = zip(names, row) elif hasattr(row, "__dict__"): # object @@ -1108,19 +1109,22 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b): +def _merge_type(a, b, name=None): if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + if name is not None: + raise TypeError("Can not merge type %s and %s in column %s" % (type(a), type(b), name)) + else: + raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), f.name)) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: From b29434e939cbb8e5eb3f3fc3e36e33bc8eab2cf1 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Wed, 22 Nov 2017 10:54:25 +0000 Subject: [PATCH 02/15] [SPARK-22566][PYTHON] Simplify conditional --- python/pyspark/sql/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 3e34237513dd1..43ffe35e6a265 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1116,10 +1116,10 @@ def _merge_type(a, b, name=None): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - if name is not None: - raise TypeError("Can not merge type %s and %s in column %s" % (type(a), type(b), name)) - else: + if name is None: raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + else: + raise TypeError("Can not merge type %s and %s in column %s" % (type(a), type(b), name)) # same type if isinstance(a, StructType): From 6aa99631154cc0a6809832a1333ee46959efcbfb Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Wed, 22 Nov 2017 11:20:21 +0000 Subject: [PATCH 03/15] [SPARK-22566][PYTHON] Pass name when recursing in _merge_types --- python/pyspark/sql/types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 43ffe35e6a265..ef0974676bb73 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1119,7 +1119,7 @@ def _merge_type(a, b, name=None): if name is None: raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) else: - raise TypeError("Can not merge type %s and %s in column %s" % (type(a), type(b), name)) + raise TypeError("Can not merge type %s and %s in field '%s'" % (type(a), type(b), name)) # same type if isinstance(a, StructType): @@ -1133,11 +1133,11 @@ def _merge_type(a, b, name=None): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) + return ArrayType(_merge_type(a.elementType, b.elementType, name=name), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), + return MapType(_merge_type(a.keyType, b.keyType, name=name), + _merge_type(a.valueType, b.valueType, name=name), True) else: return a From 1e0072b86eded88c0a2d7b07c5f17339389d8120 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Thu, 23 Nov 2017 18:30:42 +0000 Subject: [PATCH 04/15] [SPARK-22566][PYTHON] Make _merge_type error message "recursive" --- python/pyspark/sql/types.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ef0974676bb73..621506d955107 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1109,22 +1109,23 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b, name=None): +def _merge_type(a, b, path=''): if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - if name is None: + if path == '': raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) else: - raise TypeError("Can not merge type %s and %s in field '%s'" % (type(a), type(b), name)) + raise TypeError("%s: Can not merge type %s and %s" % (path, type(a), type(b))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), f.name)) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), + path='%s.structField("%s")' % (path, f.name))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: @@ -1133,11 +1134,11 @@ def _merge_type(a, b, name=None): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType, name=name), True) + return ArrayType(_merge_type(a.elementType, b.elementType, path=path + '.arrayElement'), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType, name=name), - _merge_type(a.valueType, b.valueType, name=name), + return MapType(_merge_type(a.keyType, b.keyType, path=path + '.mapKey'), + _merge_type(a.valueType, b.valueType, path=path + '.mapValue'), True) else: return a From 61ace285bdbe6bff36ef5e0d8c8e10f9bc5a227b Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Thu, 23 Nov 2017 18:31:10 +0000 Subject: [PATCH 05/15] [SPARK-22566][PYTHON] Add tests for _merge_type --- python/pyspark/sql/tests.py | 61 +++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 762afe0d730f3..76b8d6f2ca1c6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -62,6 +62,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings +from pyspark.sql.types import _merge_type from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -1722,6 +1723,66 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type(ArrayType(LongType()), ArrayType(LongType())), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'arrayElement'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'mapKey'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'mapValue'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): + _merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(DoubleType())), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): + _merge_type( + StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", MapType(StringType(), DoubleType())), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)\.arrayElement\.mapKey'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + def test_filter_with_datetime(self): time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) date = time.date() From 3346a6ca65cc46a4912d81fd9a5086a406e875c5 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Thu, 23 Nov 2017 19:04:07 +0000 Subject: [PATCH 06/15] [SPARK-22566][PYTHON] Lint Python --- python/pyspark/sql/tests.py | 39 ++++++++++++++++++++++++++----------- python/pyspark/sql/types.py | 3 ++- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 76b8d6f2ca1c6..641e22d499660 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1729,7 +1729,10 @@ def test_merge_type(self): self.assertEqual(_merge_type(LongType(), LongType()), LongType()) - self.assertEqual(_merge_type(ArrayType(LongType()), ArrayType(LongType())), ArrayType(LongType())) + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) with self.assertRaisesRegexp(TypeError, 'arrayElement'): _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) @@ -1761,22 +1764,36 @@ def test_merge_type(self): ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): _merge_type( - StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), - StructType([StructField("f1", ArrayType(DoubleType())), StructField("f2", StringType())])) + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) self.assertEqual(_merge_type( - StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())]), - StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())]) - ), StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())])) + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())])) with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): _merge_type( - StructType([StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())]), - StructType([StructField("f1", MapType(StringType(), DoubleType())), StructField("f2", StringType())])) + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) self.assertEqual(_merge_type( - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) - ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)\.arrayElement\.mapKey'): _merge_type( StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 621506d955107..9f5f3e9d007cb 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1134,7 +1134,8 @@ def _merge_type(a, b, path=''): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType, path=path + '.arrayElement'), True) + return ArrayType(_merge_type(a.elementType, b.elementType, + path=path + '.arrayElement'), True) elif isinstance(a, MapType): return MapType(_merge_type(a.keyType, b.keyType, path=path + '.mapKey'), From 8665115a524b4b2c7f9e1cd1cb69f8038fb1c904 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Tue, 28 Nov 2017 18:17:34 +0000 Subject: [PATCH 07/15] [SPARK-22566][PYTHON] Remove r prefix for assertRaisesRegexp --- python/pyspark/sql/tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 641e22d499660..7c00c494757d9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1753,7 +1753,7 @@ def test_merge_type(self): StructType([StructField("f1", LongType()), StructField("f2", StringType())]), StructType([StructField("f1", LongType()), StructField("f2", StringType())]) ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): + with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)'): _merge_type( StructType([StructField("f1", LongType()), StructField("f2", StringType())]), StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) @@ -1762,7 +1762,7 @@ def test_merge_type(self): StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): + with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)'): _merge_type( StructType([ StructField("f1", ArrayType(LongType())), @@ -1781,7 +1781,7 @@ def test_merge_type(self): ), StructType([ StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)'): + with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)'): _merge_type( StructType([ StructField("f1", MapType(StringType(), LongType())), @@ -1794,7 +1794,7 @@ def test_merge_type(self): StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) - with self.assertRaisesRegexp(TypeError, r'structField\("f1"\)\.arrayElement\.mapKey'): + with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)\.arrayElement\.mapKey'): _merge_type( StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) From c6032518774022f026d33656787a6387d6f83b5c Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Wed, 29 Nov 2017 16:13:46 +0000 Subject: [PATCH 08/15] [SPARK-22566][PYTHON] Make error message similar to #18521 --- python/pyspark/sql/tests.py | 23 ++++++++++++++++------- python/pyspark/sql/types.py | 28 +++++++++++++++++++--------- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7c00c494757d9..00c8131d92c06 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1733,18 +1733,18 @@ def test_merge_type(self): ArrayType(LongType()), ArrayType(LongType()) ), ArrayType(LongType())) - with self.assertRaisesRegexp(TypeError, 'arrayElement'): + with self.assertRaisesRegexp(TypeError, 'element in array'): _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) self.assertEqual(_merge_type( MapType(StringType(), LongType()), MapType(StringType(), LongType()) ), MapType(StringType(), LongType())) - with self.assertRaisesRegexp(TypeError, 'mapKey'): + with self.assertRaisesRegexp(TypeError, 'key of map'): _merge_type( MapType(StringType(), LongType()), MapType(DoubleType(), LongType())) - with self.assertRaisesRegexp(TypeError, 'mapValue'): + with self.assertRaisesRegexp(TypeError, 'value of map'): _merge_type( MapType(StringType(), LongType()), MapType(StringType(), DoubleType())) @@ -1753,16 +1753,25 @@ def test_merge_type(self): StructType([StructField("f1", LongType()), StructField("f2", StringType())]), StructType([StructField("f1", LongType()), StructField("f2", StringType())]) ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)'): + with self.assertRaisesRegexp(TypeError, 'field f1'): _merge_type( StructType([StructField("f1", LongType()), StructField("f2", StringType())]), StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + self.assertEqual(_merge_type( StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)'): + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): _merge_type( StructType([ StructField("f1", ArrayType(LongType())), @@ -1781,7 +1790,7 @@ def test_merge_type(self): ), StructType([ StructField("f1", MapType(StringType(), LongType())), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)'): + with self.assertRaisesRegexp(TypeError, 'value of map field f1'): _merge_type( StructType([ StructField("f1", MapType(StringType(), LongType())), @@ -1794,7 +1803,7 @@ def test_merge_type(self): StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) - with self.assertRaisesRegexp(TypeError, 'structField\("f1"\)\.arrayElement\.mapKey'): + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): _merge_type( StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 9f5f3e9d007cb..6395bfdc710d9 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1109,23 +1109,33 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b, path=''): +def _merge_type_path(path, addition): + if path: + return "%s in %s" % (addition, path) + return addition + + +def _merge_type(a, b, name=None): + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) + if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - if path == '': - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) - else: - raise TypeError("%s: Can not merge type %s and %s" % (path, type(a), type(b))) + raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b)))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), - path='%s.structField("%s")' % (path, f.name))) + name=new_name(f.name))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: @@ -1135,11 +1145,11 @@ def _merge_type(a, b, path=''): elif isinstance(a, ArrayType): return ArrayType(_merge_type(a.elementType, b.elementType, - path=path + '.arrayElement'), True) + name='element in array %s' % name), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType, path=path + '.mapKey'), - _merge_type(a.valueType, b.valueType, path=path + '.mapValue'), + return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name), + _merge_type(a.valueType, b.valueType, name='value of map %s' % name), True) else: return a From 2240a42847ed8155375a99c30a54d0749c655966 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Thu, 30 Nov 2017 10:26:58 +0000 Subject: [PATCH 09/15] [SPARK-22566][PYTHON] Remove unused function --- python/pyspark/sql/types.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 6395bfdc710d9..9a389d44e6b2b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1109,12 +1109,6 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type_path(path, addition): - if path: - return "%s in %s" % (addition, path) - return addition - - def _merge_type(a, b, name=None): if name is None: new_msg = lambda msg: msg From 41766fa875b987fecf910b7fa8bd9429e27ce88e Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Mon, 4 Dec 2017 12:31:15 +0000 Subject: [PATCH 10/15] [SPARK-22566][PYTHON] Use generator expression --- python/pyspark/sql/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9516594a461df..b6c957fdf9808 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -338,7 +338,7 @@ def _inferSchemaFromList(self, data, names=None): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, [_infer_schema(row, names) for row in data]) + schema = reduce(_merge_type, (_infer_schema(row, names) for row in data)) if _has_nulltype(schema): raise ValueError("Some of types cannot be determined after inferring") return schema From 0103045f751e9f1c777673bdebd0632c4e781486 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Tue, 5 Dec 2017 14:52:53 +0000 Subject: [PATCH 11/15] [SPARK-22566][PYTHON] Pass field names to _infer_schema and _inferSchema --- python/pyspark/sql/session.py | 17 ++++++----------- python/pyspark/sql/tests.py | 5 +++++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b6c957fdf9808..ad69bce3e94e9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -343,7 +343,7 @@ def _inferSchemaFromList(self, data, names=None): raise ValueError("Some of types cannot be determined after inferring") return schema - def _inferSchema(self, rdd, samplingRatio=None): + def _inferSchema(self, rdd, samplingRatio=None, names=None): """ Infer schema from an RDD of Row or tuple. @@ -360,10 +360,10 @@ def _inferSchema(self, rdd, samplingRatio=None): "Use pyspark.sql.Row instead") if samplingRatio is None: - schema = _infer_schema(first) + schema = _infer_schema(first, names=names) if _has_nulltype(schema): for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) + schema = _merge_type(schema, _infer_schema(row, names=names)) if not _has_nulltype(schema): break else: @@ -372,7 +372,7 @@ def _inferSchema(self, rdd, samplingRatio=None): else: if samplingRatio < 0.99: rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) + schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type) return schema def _createFromRDD(self, rdd, schema, samplingRatio): @@ -380,14 +380,9 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchema(rdd, samplingRatio) - converter = _create_converter(struct) + schema = self._inferSchema(rdd, samplingRatio, names=schema) + converter = _create_converter(schema) rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - struct.names[i] = name - schema = struct elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 00c8131d92c06..45554c25400fe 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -847,6 +847,11 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), From 5131db23bd48b3606c7c823cad9d5d376bae0d00 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Wed, 6 Dec 2017 13:41:39 +0000 Subject: [PATCH 12/15] [SPARK-22566][PYTHON] Remove unnecessary code branch --- python/pyspark/sql/session.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index ad69bce3e94e9..be824cf0ca40b 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -401,14 +401,9 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data, names=schema) - converter = _create_converter(struct) + schema = self._inferSchemaFromList(data, names=schema) + converter = _create_converter(schema) data = map(converter, data) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - struct.names[i] = name - schema = struct elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) From 44a1879919a61d732eea176e26ce6a79549984a0 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Mon, 11 Dec 2017 13:18:12 +0000 Subject: [PATCH 13/15] [SPARK-22566][PYTHON] Revert branch removal, add regression test --- python/pyspark/sql/session.py | 18 ++++++++++++++---- python/pyspark/sql/tests.py | 4 ++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index be824cf0ca40b..5815e55225b46 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -380,9 +380,14 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - schema = self._inferSchema(rdd, samplingRatio, names=schema) - converter = _create_converter(schema) + struct = self._inferSchema(rdd, samplingRatio, names=schema) + converter = _create_converter(struct) rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) @@ -401,9 +406,14 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - schema = self._inferSchemaFromList(data, names=schema) - converter = _create_converter(schema) + struct = self._inferSchemaFromList(data, names=schema) + converter = _create_converter(struct) data = map(converter, data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 45554c25400fe..6297f396a973c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -872,6 +872,10 @@ def test_infer_nested_schema(self): df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] df = self.spark.createDataFrame(data) From 404fdbb1b5b265c6f1f651f01f42eb62d598b788 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Tue, 12 Dec 2017 11:25:12 +0000 Subject: [PATCH 14/15] [SPARK-22566][PYTHON] Guard against short `names` list --- python/pyspark/sql/tests.py | 4 ++++ python/pyspark/sql/types.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6297f396a973c..bb8956b0f3bf9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -847,6 +847,10 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + def test_infer_schema_fails(self): with self.assertRaisesRegexp(TypeError, 'field a'): self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 9a389d44e6b2b..1a9e8e52cbec4 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1085,6 +1085,9 @@ def _infer_schema(row, names=None): else: if names is None: names = ['_%d' % i for i in range(1, len(row) + 1)] + elif len(names) < len(row): + names = names[:] + names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1)) items = zip(names, row) elif hasattr(row, "__dict__"): # object From 6d171dda179ecdbe95dbc959c961397e08b8b537 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Thu, 14 Dec 2017 15:21:23 +0000 Subject: [PATCH 15/15] [SPARK-22566][PYTHON] Remove unnecessary copying of names list --- python/pyspark/sql/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1a9e8e52cbec4..e1eb208bba9b1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1086,7 +1086,6 @@ def _infer_schema(row, names=None): if names is None: names = ['_%d' % i for i in range(1, len(row) + 1)] elif len(names) < len(row): - names = names[:] names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1)) items = zip(names, row)