diff --git a/.github/workflows/test-spark-connect.yml b/.github/workflows/test-spark-connect.yml
new file mode 100644
index 0000000..675ec1d
--- /dev/null
+++ b/.github/workflows/test-spark-connect.yml
@@ -0,0 +1,33 @@
+name: Main
+on: [push, pull_request]
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        PYTHON_VERSION: ["3.11", "3.12"]
+        JOBLIB_VERSION: ["1.3.0", "1.4.2"]
+        PIN_MODE: [false, true]
+        PYSPARK_VERSION: ["4.0.0.dev1"]
+    name: Run test with spark connect ${{ matrix.PYSPARK_VERSION }}, pin_mode ${{ matrix.PIN_MODE }}, python ${{ matrix.PYTHON_VERSION }}, joblib ${{ matrix.JOBLIB_VERSION }}
+    steps:
+    - uses: actions/checkout@v3
+    - name: Setup python ${{ matrix.PYTHON_VERSION }}
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ matrix.PYTHON_VERSION }}
+        architecture: x64
+    - name: Install python packages
+      run: |
+        pip install setuptools joblib==${{ matrix.JOBLIB_VERSION }} 'scikit-learn>=0.23.1' pytest pylint
+        pip install 'numpy==1.26.4' 'pyarrow==12.0.1' 'pandas<=2.0.3'
+        # Add Python deps for Spark Connect.
+        pip install 'grpcio>=1.48,<1.57' 'grpcio-status>=1.48,<1.57' 'protobuf==3.20.3' 'googleapis-common-protos==1.56.4'
+        pip install "pyspark[connect]==${{ matrix.PYSPARK_VERSION }}"
+    - name: Run pylint
+      run: |
+        ./run-pylint.sh
+    - name: Run test suites
+      run: |
+        PYSPARK_PIN_THREAD=${{ matrix.PIN_MODE }} ./run-tests.sh
diff --git a/test/test_spark.py b/test/test_spark.py
index 9c78257..238a5cb 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -39,21 +39,7 @@
 register_spark()
 
 
-class TestSparkCluster(unittest.TestCase):
-    spark = None
-
-    @classmethod
-    def setup_class(cls):
-        cls.spark = (
-            SparkSession.builder.master("local-cluster[1, 2, 1024]")
-            .config("spark.task.cpus", "1")
-            .config("spark.task.maxFailures", "1")
-            .getOrCreate()
-        )
-
-    @classmethod
-    def teardown_class(cls):
-        cls.spark.stop()
+class JoblibsparkTest:
 
     def test_simple(self):
         def inc(x):
@@ -117,6 +103,18 @@
         assert len(os.listdir(tmp_dir)) == 0
 
 
+class TestSparkCluster(JoblibsparkTest, unittest.TestCase):
+    def setUp(self):
+        self.spark = (
+            SparkSession.builder.master("local-cluster[1, 2, 1024]")
+            .config("spark.task.cpus", "1")
+            .config("spark.task.maxFailures", "1")
+            .getOrCreate()
+        )
+
+    def tearDown(self):
+        self.spark.stop()
+
 @unittest.skipIf(Version(pyspark.__version__).release < (3, 4, 0),
                  "Resource group is only supported since spark 3.4.0")
 class TestGPUSparkCluster(unittest.TestCase):
diff --git a/test/test_spark_connect.py b/test/test_spark_connect.py
new file mode 100644
index 0000000..944090a
--- /dev/null
+++ b/test/test_spark_connect.py
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql import SparkSession
+
+from test.test_spark import JoblibsparkTest
+
+class TestsOnSparkConnect(JoblibsparkTest, unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
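
Note on the refactor above: the shared test bodies move onto JoblibsparkTest, a plain mixin that deliberately does not extend unittest.TestCase, so the runner never collects it on its own; each concrete TestCase then supplies only its own SparkSession fixture. Below is a minimal standalone sketch of the same pattern, not part of the diff: the class names and the range()/count() assertion are illustrative only, and running it assumes pyspark with the connect extra is installed so that builder.remote("local[2]") can launch a local Spark Connect session.

import unittest

from pyspark.sql import SparkSession


class SessionTestMixin:
    # Not a unittest.TestCase, so unittest never collects this class by
    # itself; every concrete subclass provides self.spark in setUp().
    def test_session_alive(self):
        assert self.spark.range(3).count() == 3


class TestOverSparkConnect(SessionTestMixin, unittest.TestCase):
    def setUp(self) -> None:
        # Passing "local[2]" as the remote URL launches a local Spark
        # Connect server and returns a client session bound to it.
        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()

    def tearDown(self) -> None:
        self.spark.stop()


if __name__ == "__main__":
    unittest.main()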