Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[SPARK-50388][PYTHON][TESTS] Further centralize import checks #48926

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions python/pyspark/pandas/tests/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,9 @@
import pandas as pd

from pyspark import pandas as ps
from pyspark.testing.pandasutils import (
have_tabulate,
PandasOnSparkTestCase,
tabulate_requirement_message,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.utils import have_tabulate, tabulate_requirement_message


# This file contains test cases for 'Serialization / IO / Conversion'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@

from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.testing.pandasutils import (
have_matplotlib,
matplotlib_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_matplotlib, matplotlib_requirement_message

if have_matplotlib:
import matplotlib
Expand Down
8 changes: 2 additions & 6 deletions python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@

from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.testing.pandasutils import (
have_plotly,
plotly_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_plotly, plotly_requirement_message
from pyspark.pandas.utils import name_like_string

if have_plotly:
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/plot/test_series_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from pyspark import pandas as ps
from pyspark.pandas.plot import PandasOnSparkPlotAccessor, BoxPlotBase
from pyspark.testing.pandasutils import have_plotly, plotly_requirement_message
from pyspark.testing.utils import have_plotly, plotly_requirement_message


class SeriesPlotTestsMixin:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@

from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.testing.pandasutils import (
have_matplotlib,
matplotlib_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_matplotlib, matplotlib_requirement_message

if have_matplotlib:
import matplotlib
Expand Down
8 changes: 2 additions & 6 deletions python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.pandas.utils import name_like_string
from pyspark.testing.pandasutils import (
have_plotly,
plotly_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_plotly, plotly_requirement_message

if have_plotly:
from plotly import express
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/series/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pyspark import pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.pandasutils import have_tabulate, tabulate_requirement_message
from pyspark.testing.utils import have_tabulate, tabulate_requirement_message


class SeriesConversionMixin:
Expand Down
6 changes: 3 additions & 3 deletions python/pyspark/sql/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from pyspark.errors import PySparkValueError

if TYPE_CHECKING:
from pyspark.testing.connectutils import have_graphviz

if have_graphviz:
try:
import graphviz # type: ignore
except ImportError:
pass


class ObservedMetrics(abc.ABC):
Expand Down
7 changes: 2 additions & 5 deletions python/pyspark/sql/tests/connect/test_df_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@

import unittest

from pyspark.testing.connectutils import (
should_test_connect,
have_graphviz,
graphviz_requirement_message,
)
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
from pyspark.testing.connectutils import should_test_connect
from pyspark.testing.utils import have_graphviz, graphviz_requirement_message

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/plot/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import unittest
from pyspark.errors import PySparkValueError
from pyspark.sql import Row
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import (
have_plotly,
plotly_requirement_message,
have_pandas,
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from datetime import datetime

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import (
have_plotly,
plotly_requirement_message,
have_pandas,
Expand Down
23 changes: 0 additions & 23 deletions python/pyspark/testing/pandasutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,6 @@
import decimal
from typing import Any, Union

tabulate_requirement_message = None
try:
from tabulate import tabulate
except ImportError as e:
# If tabulate requirement is not satisfied, skip related tests.
tabulate_requirement_message = str(e)
have_tabulate = tabulate_requirement_message is None

matplotlib_requirement_message = None
try:
import matplotlib
except ImportError as e:
# If matplotlib requirement is not satisfied, skip related tests.
matplotlib_requirement_message = str(e)
have_matplotlib = matplotlib_requirement_message is None

plotly_requirement_message = None
try:
import plotly
except ImportError as e:
# If plotly requirement is not satisfied, skip related tests.
plotly_requirement_message = str(e)
have_plotly = plotly_requirement_message is None

try:
from pyspark.sql.pandas.utils import require_minimum_pandas_version
Expand Down
41 changes: 10 additions & 31 deletions python/pyspark/testing/sqlutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,17 @@
import tempfile
from contextlib import contextmanager

pandas_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pandas_version

require_minimum_pandas_version()
except ImportError as e:
# If Pandas version requirement is not satisfied, skip related tests.
pandas_requirement_message = str(e)

pyarrow_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import (
ReusedPySparkTestCase,
PySparkErrorTestUtils,
have_pandas,
pandas_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
)

require_minimum_pyarrow_version()
except ImportError as e:
# If Arrow version requirement is not satisfied, skip related tests.
pyarrow_requirement_message = str(e)

test_not_compiled_message = None
try:
Expand All @@ -48,21 +42,6 @@
except Exception as e:
test_not_compiled_message = str(e)

plotly_requirement_message = None
try:
import plotly
except ImportError as e:
plotly_requirement_message = str(e)
have_plotly = plotly_requirement_message is None


from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils


have_pandas = pandas_requirement_message is None
have_pyarrow = pyarrow_requirement_message is None
test_compiled = test_not_compiled_message is None


Expand Down
33 changes: 33 additions & 0 deletions python/pyspark/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ def have_package(name: str) -> bool:
have_plotly = have_package("plotly")
plotly_requirement_message = None if have_plotly else "No module named 'plotly'"

have_matplotlib = have_package("matplotlib")
matplotlib_requirement_message = None if have_matplotlib else "No module named 'matplotlib'"

have_tabulate = have_package("tabulate")
tabulate_requirement_message = None if have_tabulate else "No module named 'tabulate'"

have_graphviz = have_package("graphviz")
graphviz_requirement_message = None if have_graphviz else "No module named 'graphviz'"


pandas_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pandas_version

require_minimum_pandas_version()
except Exception as e:
# If Pandas version requirement is not satisfied, skip related tests.
pandas_requirement_message = str(e)

have_pandas = pandas_requirement_message is None


pyarrow_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()
except Exception as e:
# If Arrow version requirement is not satisfied, skip related tests.
pyarrow_requirement_message = str(e)

have_pyarrow = pyarrow_requirement_message is None


def read_int(b):
return struct.unpack("!i", b)[0]
Expand Down