diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 440684d3edfa6..95eca76fa9888 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1906,9 +1906,9 @@ def toPandas(self):
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
                 from pyspark.sql.types import _check_dataframe_localize_timestamps
-                from pyspark.sql.utils import _require_minimum_pyarrow_version
+                from pyspark.sql.utils import require_minimum_pyarrow_version
                 import pyarrow
-                _require_minimum_pyarrow_version()
+                require_minimum_pyarrow_version()
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 86db16eca7889..6e5eec48e8aca 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -493,15 +493,14 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
         data types will be used to coerce the data in Pandas to Arrow conversion.
         """
         from pyspark.serializers import ArrowSerializer, _create_batch
-        from pyspark.sql.types import from_arrow_schema, to_arrow_type, \
-            _old_pandas_exception_message, TimestampType
-        from pyspark.sql.utils import _require_minimum_pyarrow_version
-        try:
-            from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
-        except ImportError as e:
-            raise ImportError(_old_pandas_exception_message(e))
+        from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+        from pyspark.sql.utils import require_minimum_pandas_version, \
+            require_minimum_pyarrow_version
+
+        require_minimum_pandas_version()
+        require_minimum_pyarrow_version()
 
-        _require_minimum_pyarrow_version()
+        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
 
         # Determine arrow types to coerce data when creating batches
         if isinstance(schema, StructType):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6fdfda1cc831b..b977160af566d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -53,7 +53,8 @@
 try:
     import pandas
     try:
-        import pandas.api
+        from pyspark.sql.utils import require_minimum_pandas_version
+        require_minimum_pandas_version()
         _have_pandas = True
     except:
         _have_old_pandas = True
@@ -2600,7 +2601,7 @@ def test_to_pandas(self):
     @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
     def test_to_pandas_old(self):
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                 self._to_pandas()
 
     @unittest.skipIf(not _have_pandas, "Pandas not installed")
@@ -2643,7 +2644,7 @@ def test_create_dataframe_from_old_pandas(self):
         pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                             "d": [pd.Timestamp.now().date()]})
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                 self.spark.createDataFrame(pdf)
 
 
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 46d9a417414b5..063264a89379c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1678,13 +1678,6 @@ def from_arrow_schema(arrow_schema):
          for field in arrow_schema])
 
 
-def _old_pandas_exception_message(e):
-    """ Create an error message for importing old Pandas.
-    """
-    msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process"
-    return "%s\n%s" % (_exception_message(e), msg)
-
-
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1693,10 +1686,10 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
     :param timezone: the timezone to convert. if None then use local timezone
     :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
     """
-    try:
-        from pandas.api.types import is_datetime64tz_dtype
-    except ImportError as e:
-        raise ImportError(_old_pandas_exception_message(e))
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64tz_dtype
     tz = timezone or 'tzlocal()'
     for column, series in pdf.iteritems():
         # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
@@ -1714,10 +1707,10 @@ def _check_series_convert_timestamps_internal(s, timezone):
     :param timezone: the timezone to convert. if None then use local timezone
     :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
     """
-    try:
-        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
-    except ImportError as e:
-        raise ImportError(_old_pandas_exception_message(e))
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
     # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
     if is_datetime64_dtype(s.dtype):
         tz = timezone or 'tzlocal()'
@@ -1737,11 +1730,11 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
     :param to_timezone: the timezone to convert to. if None then use local timezone
     :return pandas.Series where if it is a timestamp, has been converted to tz-naive
     """
-    try:
-        import pandas as pd
-        from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
-    except ImportError as e:
-        raise ImportError(_old_pandas_exception_message(e))
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    import pandas as pd
+    from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
     from_tz = from_timezone or 'tzlocal()'
     to_tz = to_timezone or 'tzlocal()'
     # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 50c87ba1ac882..123138117fdc3 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -37,9 +37,9 @@ def _create_udf(f, returnType, evalType):
     if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \
             evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
         import inspect
-        from pyspark.sql.utils import _require_minimum_pyarrow_version
+        from pyspark.sql.utils import require_minimum_pyarrow_version
 
-        _require_minimum_pyarrow_version()
+        require_minimum_pyarrow_version()
         argspec = inspect.getargspec(f)
 
         if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index cc7dabb64b3ec..fb7d42a35d8f4 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -112,7 +112,16 @@ def toJArray(gateway, jtype, arr):
     return jarr
 
 
-def _require_minimum_pyarrow_version():
+def require_minimum_pandas_version():
+    """ Raise ImportError if minimum version of Pandas is not installed
+    """
+    from distutils.version import LooseVersion
+    import pandas
+    if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
+        raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process")
+
+
+def require_minimum_pyarrow_version():
     """ Raise ImportError if minimum version of pyarrow is not installed
     """
     from distutils.version import LooseVersion
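
Note on the pattern this patch settles on: both helpers in utils.py are fail-fast version gates built on distutils' LooseVersion, raising ImportError before any Pandas- or Arrow-dependent code path runs. Using ImportError for both "package missing" and "package too old" keeps the failure mode uniform, which is exactly what the updated tests assert with the 'Pandas >= .* must be installed' regex. A minimal standalone sketch of the same idea, generalized to any dependency (the helper name require_minimum_version and the example call are illustrative, not part of this patch):

    from distutils.version import LooseVersion
    import importlib


    def require_minimum_version(package, minimum):
        """ Raise ImportError if `package` is missing or older than `minimum`.

        Illustrative generalization of require_minimum_pandas_version /
        require_minimum_pyarrow_version above; not part of this patch.
        """
        try:
            module = importlib.import_module(package)
        except ImportError:
            # Missing entirely: report the same requirement as a too-old install.
            raise ImportError("%s >= %s must be installed on calling Python process"
                              % (package, minimum))
        # Assumes the package exposes __version__, as pandas and pyarrow do.
        if LooseVersion(module.__version__) < LooseVersion(minimum):
            raise ImportError("%s >= %s must be installed on calling Python process; "
                              "found %s" % (package, minimum, module.__version__))


    # Fail fast before entering a dependency-specific code path, mirroring how
    # _create_from_pandas_with_arrow calls the helpers at the top of the function.
    require_minimum_version("pandas", "0.19.2")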