Skip to content

Commit

Permalink
[SPARK-22566][PYTHON] Revert branch removal, add regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
gberger-palantir committed Dec 11, 2017
1 parent 5131db2 commit 44a1879
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
18 changes: 14 additions & 4 deletions python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,14 @@ def _createFromRDD(self, rdd, schema, samplingRatio):
Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
"""
if schema is None or isinstance(schema, (list, tuple)):
schema = self._inferSchema(rdd, samplingRatio, names=schema)
converter = _create_converter(schema)
struct = self._inferSchema(rdd, samplingRatio, names=schema)
converter = _create_converter(struct)
rdd = rdd.map(converter)
if isinstance(schema, (list, tuple)):
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct

elif not isinstance(schema, StructType):
raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
Expand All @@ -401,9 +406,14 @@ def _createFromLocal(self, data, schema):
data = list(data)

if schema is None or isinstance(schema, (list, tuple)):
schema = self._inferSchemaFromList(data, names=schema)
converter = _create_converter(schema)
struct = self._inferSchemaFromList(data, names=schema)
converter = _create_converter(struct)
data = map(converter, data)
if isinstance(schema, (list, tuple)):
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct

elif not isinstance(schema, StructType):
raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,10 @@ def test_infer_nested_schema(self):
df = self.spark.createDataFrame(rdd)
self.assertEqual(Row(field1=1, field2=u'row1'), df.first())

def test_create_dataframe_from_dict_respects_schema(self):
    """Regression test for SPARK-22566: supplied column names override dict keys.

    The only input row is a dict keyed by 'a', while the schema argument
    handed to createDataFrame is the name list ["b"]. The resulting
    DataFrame must report the supplied name as its single column.
    """
    rows = [{'a': 1}]
    requested_names = ["b"]
    frame = self.spark.createDataFrame(rows, requested_names)
    # Compare against a fresh literal rather than requested_names so the
    # expectation does not alias the list that was passed in.
    self.assertEqual(frame.columns, ['b'])

def test_create_dataframe_from_objects(self):
data = [MyObject(1, "1"), MyObject(2, "2")]
df = self.spark.createDataFrame(data)
Expand Down

0 comments on commit 44a1879

Please sign in to comment.