From 44a1879919a61d732eea176e26ce6a79549984a0 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Mon, 11 Dec 2017 13:18:12 +0000 Subject: [PATCH] [SPARK-22566][PYTHON] Revert branch removal, add regression test --- python/pyspark/sql/session.py | 18 ++++++++++++++---- python/pyspark/sql/tests.py | 4 ++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index be824cf0ca40b..5815e55225b46 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -380,9 +380,14 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - schema = self._inferSchema(rdd, samplingRatio, names=schema) - converter = _create_converter(schema) + struct = self._inferSchema(rdd, samplingRatio, names=schema) + converter = _create_converter(struct) rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) @@ -401,9 +406,14 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - schema = self._inferSchemaFromList(data, names=schema) - converter = _create_converter(schema) + struct = self._inferSchemaFromList(data, names=schema) + converter = _create_converter(struct) data = map(converter, data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 45554c25400fe..6297f396a973c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -872,6 +872,10 @@ def test_infer_nested_schema(self): df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] df = self.spark.createDataFrame(data)