diff --git a/python/pyspark/pandas/tests/io/test_io.py b/python/pyspark/pandas/tests/io/test_io.py
index d4e61319f229c..6fbdc366dd76a 100644
--- a/python/pyspark/pandas/tests/io/test_io.py
+++ b/python/pyspark/pandas/tests/io/test_io.py
@@ -22,12 +22,9 @@
 import pandas as pd
 
 from pyspark import pandas as ps
-from pyspark.testing.pandasutils import (
-    have_tabulate,
-    PandasOnSparkTestCase,
-    tabulate_requirement_message,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.testing.utils import have_tabulate, tabulate_requirement_message
 
 
 # This file contains test cases for 'Serialization / IO / Conversion'
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
index 365d34b1f550e..1d63cafe19b42 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
@@ -24,12 +24,8 @@
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
-from pyspark.testing.pandasutils import (
-    have_matplotlib,
-    matplotlib_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_matplotlib, matplotlib_requirement_message
 
 if have_matplotlib:
     import matplotlib
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
index 8d197649aaebe..5308932573330 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
@@ -23,12 +23,8 @@
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
-from pyspark.testing.pandasutils import (
-    have_plotly,
-    plotly_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_plotly, plotly_requirement_message
 from pyspark.pandas.utils import name_like_string
 
 if have_plotly:
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot.py b/python/pyspark/pandas/tests/plot/test_series_plot.py
index 6e0bdd232fc41..61d114f37b0e8 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot.py
@@ -22,7 +22,7 @@
 
 from pyspark import pandas as ps
 from pyspark.pandas.plot import PandasOnSparkPlotAccessor, BoxPlotBase
-from pyspark.testing.pandasutils import have_plotly, plotly_requirement_message
+from pyspark.testing.utils import have_plotly, plotly_requirement_message
 
 
 class SeriesPlotTestsMixin:
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
index c98c1aeea04e7..0fdcbc9d748e0 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
@@ -24,12 +24,8 @@
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
-from pyspark.testing.pandasutils import (
-    have_matplotlib,
-    matplotlib_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_matplotlib, matplotlib_requirement_message
 
 if have_matplotlib:
     import matplotlib
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
index 1aa175f9308a1..8123af26dbf4b 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
@@ -24,12 +24,8 @@
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
 from pyspark.pandas.utils import name_like_string
-from pyspark.testing.pandasutils import (
-    have_plotly,
-    plotly_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_plotly, plotly_requirement_message
 
 if have_plotly:
     from plotly import express
diff --git a/python/pyspark/pandas/tests/series/test_conversion.py b/python/pyspark/pandas/tests/series/test_conversion.py
index 71ae858631d4d..7711d05abd76d 100644
--- a/python/pyspark/pandas/tests/series/test_conversion.py
+++ b/python/pyspark/pandas/tests/series/test_conversion.py
@@ -21,7 +21,7 @@
 from pyspark import pandas as ps
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.testing.pandasutils import have_tabulate, tabulate_requirement_message
+from pyspark.testing.utils import have_tabulate, tabulate_requirement_message
 
 
 class SeriesConversionMixin:
diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py
index 0f4142e91b256..4ab9b041e3135 100644
--- a/python/pyspark/sql/metrics.py
+++ b/python/pyspark/sql/metrics.py
@@ -21,10 +21,10 @@
 from pyspark.errors import PySparkValueError
 
 if TYPE_CHECKING:
-    from pyspark.testing.connectutils import have_graphviz
-
-    if have_graphviz:
+    try:
         import graphviz  # type: ignore
+    except ImportError:
+        pass
 
 
 class ObservedMetrics(abc.ABC):
diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py
index 8a4ec68fda844..40b6a072e9127 100644
--- a/python/pyspark/sql/tests/connect/test_df_debug.py
+++ b/python/pyspark/sql/tests/connect/test_df_debug.py
@@ -17,12 +17,9 @@
 
 import unittest
 
-from pyspark.testing.connectutils import (
-    should_test_connect,
-    have_graphviz,
-    graphviz_requirement_message,
-)
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.utils import have_graphviz, graphviz_requirement_message
 
 if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py
index 3221a408d153d..c37aef5f7c94f 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -18,8 +18,8 @@
 import unittest
 from pyspark.errors import PySparkValueError
 from pyspark.sql import Row
-from pyspark.testing.sqlutils import (
-    ReusedSQLTestCase,
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import (
     have_plotly,
     plotly_requirement_message,
     have_pandas,
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index 84a9c2aa01706..fd264c3488823 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -19,8 +19,8 @@
 from datetime import datetime
 
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.testing.sqlutils import (
-    ReusedSQLTestCase,
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import (
     have_plotly,
     plotly_requirement_message,
     have_pandas,
diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py
index 10e8ce6f69af3..09d3ffb09708f 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -23,29 +23,6 @@
 import decimal
 from typing import Any, Union
 
-tabulate_requirement_message = None
-try:
-    from tabulate import tabulate
-except ImportError as e:
-    # If tabulate requirement is not satisfied, skip related tests.
-    tabulate_requirement_message = str(e)
-have_tabulate = tabulate_requirement_message is None
-
-matplotlib_requirement_message = None
-try:
-    import matplotlib
-except ImportError as e:
-    # If matplotlib requirement is not satisfied, skip related tests.
-    matplotlib_requirement_message = str(e)
-have_matplotlib = matplotlib_requirement_message is None
-
-plotly_requirement_message = None
-try:
-    import plotly
-except ImportError as e:
-    # If plotly requirement is not satisfied, skip related tests.
-    plotly_requirement_message = str(e)
-have_plotly = plotly_requirement_message is None
 
 try:
     from pyspark.sql.pandas.utils import require_minimum_pandas_version
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index c833abfb805dc..e5464257422ae 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -22,23 +22,17 @@
 import tempfile
 from contextlib import contextmanager
 
-pandas_requirement_message = None
-try:
-    from pyspark.sql.pandas.utils import require_minimum_pandas_version
-
-    require_minimum_pandas_version()
-except ImportError as e:
-    # If Pandas version requirement is not satisfied, skip related tests.
-    pandas_requirement_message = str(e)
-
-pyarrow_requirement_message = None
-try:
-    from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
+from pyspark.sql import SparkSession
+from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
+from pyspark.testing.utils import (
+    ReusedPySparkTestCase,
+    PySparkErrorTestUtils,
+    have_pandas,
+    pandas_requirement_message,
+    have_pyarrow,
+    pyarrow_requirement_message,
+)
 
-    require_minimum_pyarrow_version()
-except ImportError as e:
-    # If Arrow version requirement is not satisfied, skip related tests.
-    pyarrow_requirement_message = str(e)
 
 test_not_compiled_message = None
 try:
@@ -48,21 +42,6 @@
 except Exception as e:
     test_not_compiled_message = str(e)
 
-plotly_requirement_message = None
-try:
-    import plotly
-except ImportError as e:
-    plotly_requirement_message = str(e)
-have_plotly = plotly_requirement_message is None
-
-
-from pyspark.sql import SparkSession
-from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
-from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
-
-
-have_pandas = pandas_requirement_message is None
-have_pyarrow = pyarrow_requirement_message is None
 
 test_compiled = test_not_compiled_message is None
 
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index ca16628fc56f0..1dd15666382f6 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -82,6 +82,39 @@ def have_package(name: str) -> bool:
 have_plotly = have_package("plotly")
 plotly_requirement_message = None if have_plotly else "No module named 'plotly'"
 
+have_matplotlib = have_package("matplotlib")
+matplotlib_requirement_message = None if have_matplotlib else "No module named 'matplotlib'"
+
+have_tabulate = have_package("tabulate")
+tabulate_requirement_message = None if have_tabulate else "No module named 'tabulate'"
+
+have_graphviz = have_package("graphviz")
+graphviz_requirement_message = None if have_graphviz else "No module named 'graphviz'"
+
+
+pandas_requirement_message = None
+try:
+    from pyspark.sql.pandas.utils import require_minimum_pandas_version
+
+    require_minimum_pandas_version()
+except Exception as e:
+    # If Pandas version requirement is not satisfied, skip related tests.
+    pandas_requirement_message = str(e)
+
+have_pandas = pandas_requirement_message is None
+
+
+pyarrow_requirement_message = None
+try:
+    from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
+
+    require_minimum_pyarrow_version()
+except Exception as e:
+    # If Arrow version requirement is not satisfied, skip related tests.
+    pyarrow_requirement_message = str(e)
+
+have_pyarrow = pyarrow_requirement_message is None
+
 
 def read_int(b):
     return struct.unpack("!i", b)[0]
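
Not part of the patch: a minimal sketch of how the consolidated availability flags in pyspark.testing.utils are typically consumed by a test module after this change. Only the imported names come from the diff above; the test class and method are hypothetical.

import unittest

# Flags and requirement messages now come from pyspark.testing.utils
# (see the last hunk above) instead of per-module try/except blocks.
from pyspark.testing.utils import have_plotly, plotly_requirement_message


# Hypothetical test case: skipped entirely when plotly is not installed,
# reusing the shared flag/message pair.
@unittest.skipUnless(have_plotly, plotly_requirement_message)
class ExamplePlotlyTest(unittest.TestCase):
    def test_plotly_importable(self):
        import plotly  # noqa: F401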