Skip to content

Commit

Permalink
[SPARK-50388][PYTHON][TESTS] Further centralize import checks
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Further centralized import checks:
1. Move the `have_xxx` flags from `sqlutils.py`, `pandasutils.py`, etc. into `utils.py`;
2. But still keep `have_pandas` and `have_pyarrow` available in `sqlutils.py` (re-imported from `utils.py`), because they are used in too many places to migrate at once.

### Why are the changes needed?
Simplify the import checks; e.g. `have_plotly` had been defined in multiple places.

### Does this PR introduce _any_ user-facing change?
no, test only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#48926 from zhengruifeng/py_dep_2.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Nov 22, 2024
1 parent 5d1f585 commit 190c504
Show file tree
Hide file tree
Showing 14 changed files with 64 additions and 97 deletions.
7 changes: 2 additions & 5 deletions python/pyspark/pandas/tests/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,9 @@
import pandas as pd

from pyspark import pandas as ps
from pyspark.testing.pandasutils import (
have_tabulate,
PandasOnSparkTestCase,
tabulate_requirement_message,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.utils import have_tabulate, tabulate_requirement_message


# This file contains test cases for 'Serialization / IO / Conversion'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@

from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.testing.pandasutils import (
have_matplotlib,
matplotlib_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_matplotlib, matplotlib_requirement_message

if have_matplotlib:
import matplotlib
Expand Down
8 changes: 2 additions & 6 deletions python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@

from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.testing.pandasutils import (
have_plotly,
plotly_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_plotly, plotly_requirement_message
from pyspark.pandas.utils import name_like_string

if have_plotly:
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/plot/test_series_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from pyspark import pandas as ps
from pyspark.pandas.plot import PandasOnSparkPlotAccessor, BoxPlotBase
from pyspark.testing.pandasutils import have_plotly, plotly_requirement_message
from pyspark.testing.utils import have_plotly, plotly_requirement_message


class SeriesPlotTestsMixin:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@

from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.testing.pandasutils import (
have_matplotlib,
matplotlib_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_matplotlib, matplotlib_requirement_message

if have_matplotlib:
import matplotlib
Expand Down
8 changes: 2 additions & 6 deletions python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.pandas.utils import name_like_string
from pyspark.testing.pandasutils import (
have_plotly,
plotly_requirement_message,
PandasOnSparkTestCase,
TestUtils,
)
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.testing.utils import have_plotly, plotly_requirement_message

if have_plotly:
from plotly import express
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/series/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pyspark import pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.pandasutils import have_tabulate, tabulate_requirement_message
from pyspark.testing.utils import have_tabulate, tabulate_requirement_message


class SeriesConversionMixin:
Expand Down
6 changes: 3 additions & 3 deletions python/pyspark/sql/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from pyspark.errors import PySparkValueError

if TYPE_CHECKING:
from pyspark.testing.connectutils import have_graphviz

if have_graphviz:
try:
import graphviz # type: ignore
except ImportError:
pass


class ObservedMetrics(abc.ABC):
Expand Down
7 changes: 2 additions & 5 deletions python/pyspark/sql/tests/connect/test_df_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@

import unittest

from pyspark.testing.connectutils import (
should_test_connect,
have_graphviz,
graphviz_requirement_message,
)
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
from pyspark.testing.connectutils import should_test_connect
from pyspark.testing.utils import have_graphviz, graphviz_requirement_message

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/plot/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import unittest
from pyspark.errors import PySparkValueError
from pyspark.sql import Row
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import (
have_plotly,
plotly_requirement_message,
have_pandas,
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from datetime import datetime

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import (
have_plotly,
plotly_requirement_message,
have_pandas,
Expand Down
23 changes: 0 additions & 23 deletions python/pyspark/testing/pandasutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,6 @@
import decimal
from typing import Any, Union

tabulate_requirement_message = None
try:
from tabulate import tabulate
except ImportError as e:
# If tabulate requirement is not satisfied, skip related tests.
tabulate_requirement_message = str(e)
have_tabulate = tabulate_requirement_message is None

matplotlib_requirement_message = None
try:
import matplotlib
except ImportError as e:
# If matplotlib requirement is not satisfied, skip related tests.
matplotlib_requirement_message = str(e)
have_matplotlib = matplotlib_requirement_message is None

plotly_requirement_message = None
try:
import plotly
except ImportError as e:
# If plotly requirement is not satisfied, skip related tests.
plotly_requirement_message = str(e)
have_plotly = plotly_requirement_message is None

try:
from pyspark.sql.pandas.utils import require_minimum_pandas_version
Expand Down
41 changes: 10 additions & 31 deletions python/pyspark/testing/sqlutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,17 @@
import tempfile
from contextlib import contextmanager

pandas_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pandas_version

require_minimum_pandas_version()
except ImportError as e:
# If Pandas version requirement is not satisfied, skip related tests.
pandas_requirement_message = str(e)

pyarrow_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import (
ReusedPySparkTestCase,
PySparkErrorTestUtils,
have_pandas,
pandas_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
)

require_minimum_pyarrow_version()
except ImportError as e:
# If Arrow version requirement is not satisfied, skip related tests.
pyarrow_requirement_message = str(e)

test_not_compiled_message = None
try:
Expand All @@ -48,21 +42,6 @@
except Exception as e:
test_not_compiled_message = str(e)

plotly_requirement_message = None
try:
import plotly
except ImportError as e:
plotly_requirement_message = str(e)
have_plotly = plotly_requirement_message is None


from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils


have_pandas = pandas_requirement_message is None
have_pyarrow = pyarrow_requirement_message is None
test_compiled = test_not_compiled_message is None


Expand Down
33 changes: 33 additions & 0 deletions python/pyspark/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ def have_package(name: str) -> bool:
have_plotly = have_package("plotly")
plotly_requirement_message = None if have_plotly else "No module named 'plotly'"

have_matplotlib = have_package("matplotlib")
matplotlib_requirement_message = None if have_matplotlib else "No module named 'matplotlib'"

have_tabulate = have_package("tabulate")
tabulate_requirement_message = None if have_tabulate else "No module named 'tabulate'"

have_graphviz = have_package("graphviz")
graphviz_requirement_message = None if have_graphviz else "No module named 'graphviz'"


pandas_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pandas_version

require_minimum_pandas_version()
except Exception as e:
# If Pandas version requirement is not satisfied, skip related tests.
pandas_requirement_message = str(e)

have_pandas = pandas_requirement_message is None


pyarrow_requirement_message = None
try:
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()
except Exception as e:
# If Arrow version requirement is not satisfied, skip related tests.
pyarrow_requirement_message = str(e)

have_pyarrow = pyarrow_requirement_message is None


def read_int(b):
return struct.unpack("!i", b)[0]
Expand Down

0 comments on commit 190c504

Please sign in to comment.