[SPARK-16259][PYSPARK] cleanup options in DataFrame read/write API
## What changes were proposed in this pull request?

There is some duplicated code for handling options in the DataFrame reader/writer API; this PR cleans it up, and it also fixes a bug in the `escapeQuotes` option of `csv()`.
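
For illustration only (not part of the patch), a minimal self-contained sketch of both changes; `FakeWriter` is a hypothetical stand-in for `DataFrameWriter` that just records options:

```python
class FakeWriter(object):
    """Toy stand-in for DataFrameWriter; records options in a dict."""

    def __init__(self):
        self.opts = {}

    def option(self, key, value):
        self.opts[key] = value
        return self

    def csv_before(self, nullValue=None, escapeQuotes=None):
        # Pre-patch pattern: one `if` per option, with a copy-paste bug.
        if nullValue is not None:
            self.option("nullValue", nullValue)
        if escapeQuotes is not None:
            self.option("escapeQuotes", nullValue)  # bug: wrong variable passed

    def csv_after(self, nullValue=None, escapeQuotes=None):
        # Post-patch pattern: one generic helper, None values skipped.
        self._set_opts(nullValue=nullValue, escapeQuotes=escapeQuotes)

    def _set_opts(self, **options):
        for k, v in options.items():
            if v is not None:
                self.option(k, v)


w = FakeWriter()
w.csv_before(nullValue="NA", escapeQuotes=False)
print(w.opts)  # {'nullValue': 'NA', 'escapeQuotes': 'NA'} -- the flag was lost

w = FakeWriter()
w.csv_after(nullValue="NA", escapeQuotes=False)
print(w.opts)  # {'nullValue': 'NA', 'escapeQuotes': False}
```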

## How was this patch tested?

Existing tests.
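
For reference, a quick manual check one could run against this change (not part of the patch's test suite; assumes an active `SparkSession` bound to the name `spark`):

```python
# Sanity-check sketch; assumes an active SparkSession named `spark`.
import os
import tempfile

df = spark.createDataFrame([('Hive', 'says "hello"')], ['tool', 'msg'])
path = os.path.join(tempfile.mkdtemp(), 'data')
# With the fix, escapeQuotes=False reaches the underlying writer as-is
# instead of being silently overwritten by the nullValue argument.
df.write.csv(path, escapeQuotes=False, nullValue='NA')
```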

Author: Davies Liu <[email protected]>

Closes #13948 from davies/csv_options.
Davies Liu authored and zsxwing committed Jun 29, 2016
1 parent 22b4072 commit 345212b
Showing 1 changed file with 20 additions and 99 deletions.
python/pyspark/sql/readwriter.py
@@ -44,84 +44,20 @@ def to_str(value):
         return str(value)
 
 
-class ReaderUtils(object):
+class OptionUtils(object):
 
-    def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
-                       allowComments, allowUnquotedFieldNames, allowSingleQuotes,
-                       allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
-                       mode, columnNameOfCorruptRecord):
+    def _set_opts(self, schema=None, **options):
         """
-        Set options based on the Json optional parameters
+        Set named options (filter out those the value is None)
         """
         if schema is not None:
             self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
-
-    def _set_csv_opts(self, schema, sep, encoding, quote, escape,
-                      comment, header, inferSchema, ignoreLeadingWhiteSpace,
-                      ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
-                      dateFormat, maxColumns, maxCharsPerColumn, maxMalformedLogPerPartition, mode):
-        """
-        Set options based on the CSV optional parameters
-        """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if maxMalformedLogPerPartition is not None:
-            self.option("maxMalformedLogPerPartition", maxMalformedLogPerPartition)
-        if mode is not None:
-            self.option("mode", mode)
-
-
-class DataFrameReader(ReaderUtils):
+        for k, v in options.items():
+            if v is not None:
+                self.option(k, v)
+
+
+class DataFrameReader(OptionUtils):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
@@ -270,7 +206,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
         [('age', 'bigint'), ('name', 'string')]
         """
-        self._set_json_opts(
+        self._set_opts(
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -413,7 +349,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         >>> df.dtypes
         [('_c0', 'string'), ('_c1', 'string')]
         """
-        self._set_csv_opts(
+        self._set_opts(
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -484,7 +420,7 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar
         return self._df(self._jreader.jdbc(url, table, jprop))
 
 
-class DataFrameWriter(object):
+class DataFrameWriter(OptionUtils):
     """
     Interface used to write a :class:`DataFrame` to external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write`
@@ -649,8 +585,7 @@ def json(self, path, mode=None, compression=None):
         >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
         """
         self.mode(mode)
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.json(path)
 
     @since(1.4)
@@ -676,8 +611,7 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None):
         self.mode(mode)
         if partitionBy is not None:
             self.partitionBy(partitionBy)
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.parquet(path)
 
     @since(1.6)
@@ -692,8 +626,7 @@ def text(self, path, compression=None):
         The DataFrame must have only one column that is of string type.
         Each row becomes a new line in the output file.
         """
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.text(path)
 
     @since(2.0)
@@ -731,20 +664,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
         >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
         """
         self.mode(mode)
-        if compression is not None:
-            self.option("compression", compression)
-        if sep is not None:
-            self.option("sep", sep)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if header is not None:
-            self.option("header", header)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if escapeQuotes is not None:
-            self.option("escapeQuotes", nullValue)
+        self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
+                       nullValue=nullValue, escapeQuotes=escapeQuotes)
         self._jwrite.csv(path)
 
     @since(1.5)
@@ -803,7 +724,7 @@ def jdbc(self, url, table, mode=None, properties=None):
         self._jwrite.mode(mode).jdbc(url, table, jprop)
 
 
-class DataStreamReader(ReaderUtils):
+class DataStreamReader(OptionUtils):
     """
     Interface used to load a streaming :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
@@ -965,7 +886,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
         >>> json_sdf.schema == sdf_schema
         True
         """
-        self._set_json_opts(
+        self._set_opts(
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -1095,7 +1016,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         >>> csv_sdf.schema == sdf_schema
         True
         """
-        self._set_csv_opts(
+        self._set_opts(
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
