[SPARK-8060] Improve DataFrame Python reader/writer interface doc and testing.
rxin committed Jun 3, 2015
1 parent 5cd6a63 commit c9902fa
Showing 19 changed files with 109 additions and 97 deletions.
12 changes: 11 additions & 1 deletion python/pyspark/sql/__init__.py
@@ -45,11 +45,20 @@


def since(version):
    """
    Annotates a function to append the version of Spark the function was added.
    """
    import re
    indent_p = re.compile(r'\n( +)')

    def deco(f):
        indents = indent_p.findall(f.__doc__)
        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
        f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
        return f
    return deco
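
A short illustration (not part of the commit) of the new indentation handling, using the ``since`` helper defined above:

@since('1.4')
def sample():
    """Returns a sample result.

    Extended description, indented by four spaces.
    """

# The appended directive picks up the smallest indentation found in the docstring:
print(sample.__doc__.splitlines()[-1])  # prints '    .. versionadded:: 1.4'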


from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.column import Column
@@ -58,6 +67,7 @@ def deco(f):
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec


__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
15 changes: 10 additions & 5 deletions python/pyspark/sql/dataframe.py
@@ -636,6 +636,9 @@ def describe(self, *cols):
This includes count, mean, stddev, min, and max. If no columns are
given, this function computes statistics for all numerical columns.
.. note:: This function is meant for exploratory data analysis, as we make no \
guarantee about the backward compatibility of the schema of the resulting DataFrame.
>>> df.describe().show()
+-------+---+
|summary|age|
@@ -653,9 +656,11 @@
@ignore_unicode_prefix
@since(1.3)
def head(self, n=None):
"""
Returns the first ``n`` rows as a list of :class:`Row`,
or the first :class:`Row` if ``n`` is ``None.``
"""Returns the first ``n`` rows.
If n is greater than 1, return a list of :class:`Row`. If n is 1, return a single Row.
:param n: int, default 1.
>>> df.head()
Row(age=2, name=u'Alice')
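For contrast, a hedged example of the list form (not in this diff; it assumes the module's usual two-row doctest DataFrame with Alice, 2 and Bob, 5):
>>> df.head(2)
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]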
@@ -1170,8 +1175,8 @@ def freqItems(self, cols, support=None):
"http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
:func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.
.. note:: This function is meant for exploratory data analysis, as we make no \
    guarantee about the backward compatibility of the schema of the resulting DataFrame.
:param cols: Names of the columns to calculate frequent items for as a list or tuple of
strings.
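A brief hedged illustration (not in this diff): the exact frequent values are data-dependent, but the result columns follow the ``<col>_freqItems`` naming scheme.
>>> items = df.freqItems(['age', 'name'], support=0.6)
>>> items.columns
['age_freqItems', 'name_freqItems']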
164 changes: 78 additions & 86 deletions python/pyspark/sql/readwriter.py
@@ -45,18 +45,24 @@ def _df(self, jdf):

@since(1.4)
def format(self, source):
"""
Specifies the input data source format.
"""Specifies the input data source format.
:param source: string, name of the data source, e.g. 'json', 'parquet'.
>>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]
"""
self._jreader = self._jreader.format(source)
return self

@since(1.4)
def schema(self, schema):
"""
Specifies the input schema. Some data sources (e.g. JSON) can
infer the input schema automatically from data. By specifying
the schema here, the underlying data source can skip the schema
"""Specifies the input schema.
Some data sources (e.g. JSON) can infer the input schema automatically from data.
By specifying the schema here, the underlying data source can skip the schema
inference step, and thus speed up data loading.
:param schema: a StructType object
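For example (a hedged sketch, not part of this commit, reusing the people.json test file referenced elsewhere in this diff):
>>> from pyspark.sql.types import StructType, StructField, StringType, LongType
>>> s = StructType([StructField('age', LongType()), StructField('name', StringType())])
>>> sqlContext.read.schema(s).json('python/test_support/sql/people.json').dtypes
[('age', 'bigint'), ('name', 'string')]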
@@ -69,8 +75,7 @@ def schema(self):

@since(1.4)
def options(self, **options):
"""
Adds input options for the underlying data source.
"""Adds input options for the underlying data source.
"""
for k in options:
self._jreader = self._jreader.option(k, options[k])
@@ -84,6 +89,10 @@ def load(self, path=None, format=None, schema=None, **options):
:param format: optional string for format of the data source. Default to 'parquet'.
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
>>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
if format is not None:
self.format(format)
@@ -107,31 +116,10 @@ def json(self, path, schema=None):
:param path: string, path to the JSON dataset.
:param schema: an optional :class:`StructType` for the input schema.
>>> df = sqlContext.read.json('python/test_support/sql/people.json')
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]
"""
if schema is not None:
self.schema(schema)
@@ -141,24 +129,22 @@
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.
:param tableName: string, name of the table.
>>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.registerTempTable('tmpTable')
>>> sqlContext.read.table('tmpTable').dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.table(tableName))

@since(1.4)
def parquet(self, *path):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.
>>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))

@@ -221,43 +207,49 @@ def __init__(self, df):

@since(1.4)
def mode(self, saveMode):
"""
Specifies the behavior when data or table already exists. Options include:
"""Specifies the behavior when data or table already exists.
Options include:
* `append`: Append contents of this :class:`DataFrame` to existing data.
* `overwrite`: Overwrite existing data.
* `error`: Throw an exception if data already exists.
* `ignore`: Silently ignore this operation if data already exists.
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self._jwrite = self._jwrite.mode(saveMode)
return self

@since(1.4)
def format(self, source):
"""
Specifies the underlying output data source. Built-in options include
"parquet", "json", etc.
"""Specifies the underlying output data source.
:param source: string, name of the data source, e.g. 'json', 'parquet'.
>>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self._jwrite = self._jwrite.format(source)
return self

@since(1.4)
def options(self, **options):
"""
Adds output options for the underlying data source.
"""Adds output options for the underlying data source.
"""
for k in options:
self._jwrite = self._jwrite.option(k, options[k])
return self
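
A hedged example (the option key is purely illustrative; which keys are honored depends on the data source):
>>> df.write.options(myOption='someValue').json(os.path.join(tempfile.mkdtemp(), 'data'))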

@since(1.4)
def partitionBy(self, *cols):
"""
Partitions the output by the given columns on the file system.
"""Partitions the output by the given columns on the file system.
If specified, the output is laid out on the file system similar
to Hive's partitioning scheme.
:param cols: name of columns
>>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
@@ -266,8 +258,7 @@ def partitionBy(self, *cols):

@since(1.4)
def save(self, path=None, format=None, mode="error", **options):
"""
Saves the contents of the :class:`DataFrame` to a data source.
"""Saves the contents of the :class:`DataFrame` to a data source.
The data source is specified by the ``format`` and a set of ``options``.
If ``format`` is not specified, the default data source configured by
@@ -285,6 +276,8 @@ def save(self, path=None, format=None, mode="error", **options):
:param format: the format used to save
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param options: all other string options
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode).options(**options)
if format is not None:
@@ -296,8 +289,8 @@ def save(self, path=None, format=None, mode="error", **options):

@since(1.4)
def insertInto(self, tableName, overwrite=False):
"""
Inserts the content of the :class:`DataFrame` to the specified table.
"""Inserts the content of the :class:`DataFrame` to the specified table.
It requires that the schema of the class:`DataFrame` is the same as the
schema of the table.
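A hedged sketch (the target table name is hypothetical; the table must already exist, e.g. in a Hive metastore, with a schema matching ``df``):
>>> df.write.insertInto('existing_people_table')
>>> df.write.insertInto('existing_people_table', overwrite=True)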
@@ -307,8 +300,7 @@ def insertInto(self, tableName, overwrite=False):

@since(1.4)
def saveAsTable(self, name, format=None, mode="error", **options):
"""
Saves the content of the :class:`DataFrame` as the specified table.
"""Saves the content of the :class:`DataFrame` as the specified table.
In the case the table already exists, behavior of this function depends on the
save mode, specified by the `mode` function (default to throwing an exception).
@@ -328,13 +320,11 @@ def saveAsTable(self, name, format=None, mode="error", **options):
self.mode(mode).options(**options)
if format is not None:
self.format(format)
self._jwrite.saveAsTable(name)

@since(1.4)
def json(self, path, mode="error"):
"""
Saves the content of the :class:`DataFrame` in JSON format at the
specified path.
"""Saves the content of the :class:`DataFrame` in JSON format at the specified path.
Additionally, mode is used to specify the behavior of the save operation when
data already exists in the data source. There are four modes:
@@ -346,14 +336,14 @@ def json(self, path, mode="error"):
:param path: the path in any Hadoop supported file system
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self._jwrite.mode(mode).json(path)

@since(1.4)
def parquet(self, path, mode="error"):
"""
Saves the content of the :class:`DataFrame` in Parquet format at the
specified path.
"""Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
Additionally, mode is used to specify the behavior of the save operation when
data already exists in the data source. There are four modes:
@@ -365,14 +355,14 @@ def parquet(self, path, mode="error"):
:param path: the path in any Hadoop supported file system
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self._jwrite.mode(mode).parquet(path)

@since(1.4)
def jdbc(self, url, table, mode="error", properties={}):
"""
Saves the content of the :class:`DataFrame` to a external database table
via JDBC.
"""Saves the content of the :class:`DataFrame` to a external database table via JDBC.
In the case the table already exists in the external database,
behavior of this function depends on the save mode, specified by the `mode`
@@ -383,12 +373,15 @@ def jdbc(self, url, table, mode="error", properties={}):
* `error`: Throw an exception if data already exists.
* `ignore`: Silently ignore this operation if data already exists.
.. warning:: Don't create too many partitions in parallel on a large cluster;
    otherwise Spark might crash your external database systems.
:param url: a JDBC URL of the form ``jdbc:subprotocol:subname``
:param table: Name of the table in the external database.
:param mode: one of ``append``, ``overwrite``, ``error``, ``ignore`` (default: ``error``)
:param properties: JDBC database connection arguments, a list of
    arbitrary string tag/value. Normally at least a
    "user" and "password" property should be included.
"""
jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
for k in properties:
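
A hedged usage sketch (the URL, table name, and credentials are placeholders, and a JDBC driver for the target database must be on Spark's classpath):
>>> df.write.jdbc('jdbc:postgresql://localhost/testdb', 'people', mode='overwrite',
...               properties={'user': 'spark', 'password': 'secret'})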
@@ -398,24 +391,23 @@

def _test():
import doctest
import os
import tempfile
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.readwriter

os.chdir(os.environ["SPARK_HOME"])

globs = pyspark.sql.readwriter.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')

globs['tempfile'] = tempfile
globs['os'] = os
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')

(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)