[SPARK-22874][PYSPARK][SQL] Modify checking pandas version to use LooseVersion. #20054

4 changes: 2 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -1906,9 +1906,9 @@ def toPandas(self):
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
                 from pyspark.sql.types import _check_dataframe_localize_timestamps
-                from pyspark.sql.utils import _require_minimum_pyarrow_version
+                from pyspark.sql.utils import require_minimum_pyarrow_version
                 import pyarrow
-                _require_minimum_pyarrow_version()
+                require_minimum_pyarrow_version()
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
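
The renamed, now-public helper is meant to run before any Arrow-dependent work so that a missing or stale pyarrow fails fast with a clear ImportError rather than an obscure failure mid-conversion. A minimal sketch of that guard pattern, assuming only the helper above (the fallback flag is illustrative, not from this patch):

    # Minimal sketch: fail fast on an unsupported pyarrow before taking the Arrow path.
    from pyspark.sql.utils import require_minimum_pyarrow_version

    try:
        require_minimum_pyarrow_version()  # raises ImportError if pyarrow is missing or too old
        use_arrow = True
    except ImportError:
        use_arrow = False  # fall back to the slower non-Arrow conversion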
15 changes: 7 additions & 8 deletions python/pyspark/sql/session.py
@@ -493,15 +493,14 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
         data types will be used to coerce the data in Pandas to Arrow conversion.
         """
         from pyspark.serializers import ArrowSerializer, _create_batch
-        from pyspark.sql.types import from_arrow_schema, to_arrow_type, \
-            _old_pandas_exception_message, TimestampType
-        from pyspark.sql.utils import _require_minimum_pyarrow_version
-        try:
-            from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
-        except ImportError as e:
-            raise ImportError(_old_pandas_exception_message(e))
+        from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+        from pyspark.sql.utils import require_minimum_pandas_version, \
+            require_minimum_pyarrow_version
+
+        require_minimum_pandas_version()
+        require_minimum_pyarrow_version()
 
-        _require_minimum_pyarrow_version()
+        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
 
         # Determine arrow types to coerce data when creating batches
         if isinstance(schema, StructType):
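
With both checks hoisted to the top of _create_from_pandas_with_arrow, creating a DataFrame from pandas under Arrow now fails fast with one uniform ImportError when either dependency is too old. A hedged usage sketch; the active SparkSession named spark and the config toggle are assumptions, not part of this patch:

    # Assumes an active SparkSession `spark`; the Arrow path is taken only
    # when spark.sql.execution.arrow.enabled is true.
    import pandas as pd

    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    pdf = pd.DataFrame({"a": [1, 2, 3]})
    # Internally this runs require_minimum_pandas_version() and
    # require_minimum_pyarrow_version() before any conversion work.
    df = spark.createDataFrame(pdf)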
7 changes: 4 additions & 3 deletions python/pyspark/sql/tests.py
@@ -53,7 +53,8 @@
 try:
     import pandas
     try:
-        import pandas.api
+        from pyspark.sql.utils import require_minimum_pandas_version
+        require_minimum_pandas_version()
         _have_pandas = True
     except:
         _have_old_pandas = True
@@ -2600,7 +2601,7 @@ def test_to_pandas(self):
     @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
     def test_to_pandas_old(self):
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                 self._to_pandas()
 
     @unittest.skipIf(not _have_pandas, "Pandas not installed")
@@ -2643,7 +2644,7 @@ def test_create_dataframe_from_old_pandas(self):
         pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                             "d": [pd.Timestamp.now().date()]})
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                 self.spark.createDataFrame(pdf)
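
Both assertions now track the message raised by require_minimum_pandas_version. Since assertRaisesRegexp matches with re.search, the new pattern can be sanity-checked standalone (illustrative only):

    import re

    msg = "Pandas >= 0.19.2 must be installed on calling Python process"
    assert re.search(r'Pandas >= .* must be installed', msg)  # the updated test pattern matches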
33 changes: 13 additions & 20 deletions python/pyspark/sql/types.py
@@ -1678,13 +1678,6 @@ def from_arrow_schema(arrow_schema):
          for field in arrow_schema])
 
 
-def _old_pandas_exception_message(e):
-    """ Create an error message for importing old Pandas.
-    """
-    msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process"
-    return "%s\n%s" % (_exception_message(e), msg)
-
-
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1693,10 +1686,10 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
     :param timezone: the timezone to convert. if None then use local timezone
     :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
     """
-    try:
-        from pandas.api.types import is_datetime64tz_dtype
-    except ImportError as e:
-        raise ImportError(_old_pandas_exception_message(e))
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64tz_dtype
     tz = timezone or 'tzlocal()'
     for column, series in pdf.iteritems():
         # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
@@ -1714,10 +1707,10 @@ def _check_series_convert_timestamps_internal(s, timezone):
     :param timezone: the timezone to convert. if None then use local timezone
     :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
     """
-    try:
-        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
-    except ImportError as e:
-        raise ImportError(_old_pandas_exception_message(e))
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
     # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
     if is_datetime64_dtype(s.dtype):
         tz = timezone or 'tzlocal()'
@@ -1737,11 +1730,11 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
     :param to_timezone: the timezone to convert to. if None then use local timezone
     :return pandas.Series where if it is a timestamp, has been converted to tz-naive
     """
-    try:
-        import pandas as pd
-        from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
-    except ImportError as e:
-        raise ImportError(_old_pandas_exception_message(e))
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    import pandas as pd
+    from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
     from_tz = from_timezone or 'tzlocal()'
     to_tz = to_timezone or 'tzlocal()'
     # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
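
The three helpers above now share one up-front version gate and then use pandas.api.types introspection directly. The localize step itself can be reproduced with plain pandas; a hedged sketch mirroring _check_dataframe_localize_timestamps (the column name and zones are made up for illustration):

    # Plain-pandas sketch of the tz-aware -> tz-naive conversion the helper performs.
    import pandas as pd
    from pandas.api.types import is_datetime64tz_dtype

    pdf = pd.DataFrame({"ts": pd.to_datetime(["2017-10-31 01:01:01"]).tz_localize("UTC")})
    for column, series in pdf.iteritems():
        if is_datetime64tz_dtype(series.dtype):
            # convert to the target zone, then drop the timezone info
            pdf[column] = series.dt.tz_convert("US/Pacific").dt.tz_localize(None)
    print(pdf["ts"].dtype)  # datetime64[ns] -- now timezone-naive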
4 changes: 2 additions & 2 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def _create_udf(f, returnType, evalType):
if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \
evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
import inspect
from pyspark.sql.utils import _require_minimum_pyarrow_version
from pyspark.sql.utils import require_minimum_pyarrow_version

_require_minimum_pyarrow_version()
require_minimum_pyarrow_version()
argspec = inspect.getargspec(f)

if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \
Expand Down
11 changes: 10 additions & 1 deletion python/pyspark/sql/utils.py
@@ -112,7 +112,16 @@ def toJArray(gateway, jtype, arr):
     return jarr
 
 
-def _require_minimum_pyarrow_version():
+def require_minimum_pandas_version():
+    """ Raise ImportError if minimum version of Pandas is not installed
+    """
+    from distutils.version import LooseVersion
+    import pandas
+    if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
+        raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process")
+
+
+def require_minimum_pyarrow_version():
     """ Raise ImportError if minimum version of pyarrow is not installed
     """
     from distutils.version import LooseVersion
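
Per the PR title, the new checks hinge on LooseVersion rather than raw string comparison, which breaks as soon as a version component reaches two digits. A standalone illustration (not part of the patch):

    from distutils.version import LooseVersion

    # Lexicographic string comparison gets multi-digit components wrong:
    print('0.9.0' < '0.19.2')                              # False -- '9' sorts after '1'
    # LooseVersion compares numeric components, as the version checks require:
    print(LooseVersion('0.9.0') < LooseVersion('0.19.2'))  # True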