
Commit

Create datetime.date directly instead of creating datetime64[ns] as intermediate data.
ueshin committed Feb 15, 2019
1 parent 71170e7 commit 8ac7925
Showing 5 changed files with 45 additions and 28 deletions.
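
The change matters for dates beyond the datetime64[ns] range: pandas timestamps top out at 2262-04-11, so a value such as date(2262, 4, 12) (added to the tests below) cannot survive a conversion that first materializes datetime64[ns]. A minimal standalone sketch of the pyarrow behavior the new code path relies on (not part of the commit; assumes pyarrow >= 0.12 and pandas are installed):

    # Sketch only: demonstrates pyarrow's date_as_object conversion, not Spark code.
    import datetime
    import pyarrow as pa

    arr = pa.array([datetime.date(2262, 4, 12), None], type=pa.date32())
    table = pa.Table.from_arrays([arr], names=["d"])

    # date_as_object=True (the default since Arrow 0.12, see ARROW-3910) returns
    # plain datetime.date objects, so dates past the datetime64[ns] maximum of
    # 2262-04-11 come back intact instead of overflowing.
    pdf = table.to_pandas(date_as_object=True)
    print(type(pdf["d"][0]))   # <class 'datetime.date'>
    print(pdf["d"][0])         # 2262-04-12
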
5 changes: 2 additions & 3 deletions python/pyspark/serializers.py
@@ -311,10 +311,9 @@ def __init__(self, timezone, safecheck):

     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
-            _check_series_convert_date, _check_series_localize_timestamps
+            _arrow_column_to_pandas, _check_series_localize_timestamps

-        s = arrow_column.to_pandas()
-        s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
+        s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type))
         s = _check_series_localize_timestamps(s, self._timezone)
         return s
5 changes: 2 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -2107,14 +2107,13 @@ def toPandas(self):
         # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
         if use_arrow:
             try:
-                from pyspark.sql.types import _check_dataframe_convert_date, \
+                from pyspark.sql.types import _arrow_table_to_pandas, \
                     _check_dataframe_localize_timestamps
                 import pyarrow
                 batches = self._collectAsArrow()
                 if len(batches) > 0:
                     table = pyarrow.Table.from_batches(batches)
-                    pdf = table.to_pandas()
-                    pdf = _check_dataframe_convert_date(pdf, self.schema)
+                    pdf = _arrow_table_to_pandas(table, self.schema)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
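
A rough way to observe the user-facing effect of the new toPandas() path (a hypothetical session sketch, not from the commit; assumes an existing SparkSession named `spark` with pyarrow installed):

    # Hypothetical check of the end-to-end behavior; not part of the commit.
    import datetime

    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    df = spark.createDataFrame([(datetime.date(2262, 4, 12),)], "d date")
    pdf = df.toPandas()
    print(type(pdf["d"][0]))   # expected: <class 'datetime.date'>, not a datetime64[ns] value
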
5 changes: 4 additions & 1 deletion python/pyspark/sql/tests/test_arrow.py
@@ -68,14 +68,17 @@ def setUpClass(cls):
                     (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
                      date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
                     (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
-                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)),
+                    (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
+                     date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))]

         # TODO: remove version check once minimum pyarrow version is 0.10.0
         if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
             cls.schema.add(StructField("9_binary_t", BinaryType(), True))
             cls.data[0] = cls.data[0] + (bytearray(b"a"),)
             cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
             cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
+            cls.data[3] = cls.data[3] + (bytearray(b"dddd"),)

     @classmethod
     def tearDownClass(cls):
3 changes: 2 additions & 1 deletion python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -349,7 +349,8 @@ def test_vectorized_udf_dates(self):
         data = [(0, date(1969, 1, 1),),
                 (1, date(2012, 2, 2),),
                 (2, None,),
-                (3, date(2100, 4, 4),)]
+                (3, date(2100, 4, 4),),
+                (4, date(2262, 4, 12),)]
         df = self.spark.createDataFrame(data, schema=schema)

         date_copy = pandas_udf(lambda t: t, returnType=DateType())
55 changes: 35 additions & 20 deletions python/pyspark/sql/types.py
@@ -1681,38 +1681,53 @@ def from_arrow_schema(arrow_schema):
          for field in arrow_schema])


-def _check_series_convert_date(series, data_type):
-    """
-    Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+def _arrow_column_to_pandas(column, data_type):
+    """ Convert Arrow Column to pandas Series.
+
+    If the given column is a date type column, creates a series of datetime.date directly instead
+    of creating datetime64[ns] as intermediate data.

-    :param series: pandas.Series
-    :param data_type: a Spark data type for the series
+    :param column: pyarrow.lib.Column
+    :param data_type: a Spark data type for the column
     """
-    import pyarrow
+    import pandas as pd
+    import pyarrow as pa
     from distutils.version import LooseVersion
-    # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910
-    if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType:
-        return series.dt.date
+    # Since Arrow 0.11.0, support date_as_object to return datetime.date instead of np.datetime64.
+    if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
+        if type(data_type) == DateType:
+            return pd.Series(column.to_pylist(), name=column.name)
+        else:
+            return column.to_pandas()
     else:
-        return series
+        return column.to_pandas(date_as_object=True)


-def _check_dataframe_convert_date(pdf, schema):
-    """ Correct date type value to use datetime.date.
+def _arrow_table_to_pandas(table, schema):
+    """ Convert Arrow Table to pandas DataFrame.

-    Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
-    use datetime.date to match the behavior with when Arrow optimization is disabled.
+    If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11
+    or use `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as
+    intermediate data.

-    :param pdf: pandas.DataFrame
-    :param schema: a Spark schema of the pandas.DataFrame
+    :param table: pyarrow.lib.Table
+    :param schema: a Spark schema of the pyarrow.lib.Table
     """
-    import pyarrow
+    import pandas as pd
+    import pyarrow as pa
     from distutils.version import LooseVersion
-    # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910
-    if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"):
-        for field in schema:
-            pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
-    return pdf
+    # Since Arrow 0.11.0, support date_as_object to return datetime.date instead of np.datetime64.
+    if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
+        if any(type(field.dataType) == DateType for field in schema):
+            return pd.concat([_arrow_column_to_pandas(column, field.dataType)
+                              for column, field in zip(table.itercolumns(), schema)], axis=1)
+        else:
+            return table.to_pandas()
+    else:
+        return table.to_pandas(date_as_object=True)


 def _get_local_timezone():
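
For illustration only, a sketch of how the new private helper could be exercised directly (not from the commit; `_arrow_table_to_pandas` is internal API, and the sample table and schema here are made up):

    # Sketch only: builds a small Arrow table and converts it with the new helper.
    import datetime
    import pyarrow as pa
    from pyspark.sql.types import StructType, StructField, IntegerType, DateType
    from pyspark.sql.types import _arrow_table_to_pandas

    table = pa.Table.from_arrays(
        [pa.array([1, 2], type=pa.int32()),
         pa.array([datetime.date(2012, 2, 2), datetime.date(2262, 4, 12)], type=pa.date32())],
        names=["id", "d"])
    schema = StructType([StructField("id", IntegerType(), True),
                         StructField("d", DateType(), True)])

    pdf = _arrow_table_to_pandas(table, schema)
    # Date values arrive as datetime.date on any supported pyarrow version: through the
    # per-column to_pylist() fallback below 0.11, or via date_as_object=True from 0.11 on.
    print(type(pdf["d"][0]))   # <class 'datetime.date'>
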
