[SPARK-50517][PYTHON][TESTS] Group arrow function related tests
### What changes were proposed in this pull request?
Group the Arrow function related tests into dedicated `arrow` test packages.
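
As a quick summary of the regrouping (module names taken from the `dev/sparktestsupport/modules.py` changes below; this mapping is illustrative only and not part of the patch):

```python
# Old flat test module names -> new grouped locations under the arrow/ packages.
MOVED_ARROW_TESTS = {
    "pyspark.sql.tests.test_arrow": "pyspark.sql.tests.arrow.test_arrow",
    "pyspark.sql.tests.test_arrow_map": "pyspark.sql.tests.arrow.test_arrow_map",
    "pyspark.sql.tests.test_arrow_cogrouped_map": "pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
    "pyspark.sql.tests.test_arrow_grouped_map": "pyspark.sql.tests.arrow.test_arrow_grouped_map",
    "pyspark.sql.tests.test_arrow_python_udf": "pyspark.sql.tests.arrow.test_arrow_python_udf",
    "pyspark.sql.tests.connect.test_parity_arrow": "pyspark.sql.tests.connect.arrow.test_parity_arrow",
    "pyspark.sql.tests.connect.test_parity_arrow_map": "pyspark.sql.tests.connect.arrow.test_parity_arrow_map",
    "pyspark.sql.tests.connect.test_parity_arrow_grouped_map": "pyspark.sql.tests.connect.arrow.test_parity_arrow_grouped_map",
    "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map": "pyspark.sql.tests.connect.arrow.test_parity_arrow_cogrouped_map",
    "pyspark.sql.tests.connect.test_parity_arrow_python_udf": "pyspark.sql.tests.connect.arrow.test_parity_arrow_python_udf",
}
```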

### Why are the changes needed?
Test cleanup.

### Does this PR introduce _any_ user-facing change?
No, test-only changes.

### How was this patch tested?
CI.
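
For reference, a minimal local sketch of exercising a few of the regrouped modules outside of CI. This is not part of the patch; it assumes a working PySpark dev environment with `pyarrow` and `pandas` installed, and the dotted names are taken from the updated `modules.py` entries:

```python
import unittest

# Load the relocated Arrow test modules by their new dotted names; this also
# verifies that the new arrow/ packages are importable after the move.
suite = unittest.defaultTestLoader.loadTestsFromNames(
    [
        "pyspark.sql.tests.arrow.test_arrow",
        "pyspark.sql.tests.arrow.test_arrow_map",
        "pyspark.sql.tests.arrow.test_arrow_grouped_map",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)
```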

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49104 from zhengruifeng/group_connect_test_arrow.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
zhengruifeng authored and dongjoon-hyun committed Dec 9, 2024
1 parent 6c2e87a commit 85d92d7
Showing 13 changed files with 57 additions and 25 deletions.
20 changes: 10 additions & 10 deletions dev/sparktestsupport/modules.py
@@ -502,10 +502,6 @@ def __hash__(self):
         "pyspark.sql.observation",
         "pyspark.sql.tvf",
         # unittests
-        "pyspark.sql.tests.test_arrow",
-        "pyspark.sql.tests.test_arrow_cogrouped_map",
-        "pyspark.sql.tests.test_arrow_grouped_map",
-        "pyspark.sql.tests.test_arrow_python_udf",
         "pyspark.sql.tests.test_catalog",
         "pyspark.sql.tests.test_column",
         "pyspark.sql.tests.test_conf",
@@ -522,12 +518,16 @@ def __hash__(self):
         "pyspark.sql.tests.test_functions",
         "pyspark.sql.tests.test_group",
         "pyspark.sql.tests.test_sql",
+        "pyspark.sql.tests.arrow.test_arrow",
+        "pyspark.sql.tests.arrow.test_arrow_map",
+        "pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
+        "pyspark.sql.tests.arrow.test_arrow_grouped_map",
+        "pyspark.sql.tests.arrow.test_arrow_python_udf",
         "pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",
         "pyspark.sql.tests.pandas.test_pandas_map",
         "pyspark.sql.tests.pandas.test_pandas_transform_with_state",
-        "pyspark.sql.tests.test_arrow_map",
         "pyspark.sql.tests.pandas.test_pandas_udf",
         "pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg",
         "pyspark.sql.tests.pandas.test_pandas_udf_scalar",
@@ -1029,8 +1029,6 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_connect_readwriter",
         "pyspark.sql.tests.connect.test_connect_session",
         "pyspark.sql.tests.connect.test_connect_stat",
-        "pyspark.sql.tests.connect.test_parity_arrow",
-        "pyspark.sql.tests.connect.test_parity_arrow_python_udf",
         "pyspark.sql.tests.connect.test_parity_datasources",
         "pyspark.sql.tests.connect.test_parity_errors",
         "pyspark.sql.tests.connect.test_parity_catalog",
@@ -1054,9 +1052,6 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_memory_profiler",
         "pyspark.sql.tests.connect.test_parity_udtf",
         "pyspark.sql.tests.connect.test_parity_tvf",
-        "pyspark.sql.tests.connect.test_parity_arrow_map",
-        "pyspark.sql.tests.connect.test_parity_arrow_grouped_map",
-        "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
         "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
         "pyspark.sql.tests.connect.test_parity_frame_plot",
@@ -1073,6 +1068,11 @@
         "pyspark.sql.tests.connect.test_resources",
         "pyspark.sql.tests.connect.shell.test_progress",
         "pyspark.sql.tests.connect.test_df_debug",
+        "pyspark.sql.tests.connect.arrow.test_parity_arrow",
+        "pyspark.sql.tests.connect.arrow.test_parity_arrow_map",
+        "pyspark.sql.tests.connect.arrow.test_parity_arrow_grouped_map",
+        "pyspark.sql.tests.connect.arrow.test_parity_arrow_cogrouped_map",
+        "pyspark.sql.tests.connect.arrow.test_parity_arrow_python_udf",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_map",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map_with_state",
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests/arrow/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
@@ -1778,7 +1778,7 @@ def conf(cls):


 if __name__ == "__main__":
-    from pyspark.sql.tests.test_arrow import *  # noqa: F401
+    from pyspark.sql.tests.arrow.test_arrow import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore
@@ -334,7 +334,7 @@ def tearDownClass(cls):


 if __name__ == "__main__":
-    from pyspark.sql.tests.test_arrow_cogrouped_map import *  # noqa: F401
+    from pyspark.sql.tests.arrow.test_arrow_cogrouped_map import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore[import]
@@ -290,7 +290,7 @@ def tearDownClass(cls):


 if __name__ == "__main__":
-    from pyspark.sql.tests.test_arrow_grouped_map import *  # noqa: F401
+    from pyspark.sql.tests.arrow.test_arrow_grouped_map import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore[import]
@@ -195,7 +195,7 @@ def tearDownClass(cls):


 if __name__ == "__main__":
-    from pyspark.sql.tests.test_arrow_map import *  # noqa: F401
+    from pyspark.sql.tests.arrow.test_arrow_map import *  # noqa: F401

     try:
         import xmlrunner
@@ -253,7 +253,7 @@ def tearDownClass(cls):


 if __name__ == "__main__":
-    from pyspark.sql.tests.test_arrow_python_udf import *  # noqa: F401
+    from pyspark.sql.tests.arrow.test_arrow_python_udf import *  # noqa: F401

     try:
         import xmlrunner
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests/connect/arrow/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
@@ -17,7 +17,7 @@

 import unittest

-from pyspark.sql.tests.test_arrow import ArrowTestsMixin
+from pyspark.sql.tests.arrow.test_arrow import ArrowTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils

@@ -139,7 +139,7 @@ def test_create_dataframe_namedtuples(self):


 if __name__ == "__main__":
-    from pyspark.sql.tests.connect.test_parity_arrow import *  # noqa: F401
+    from pyspark.sql.tests.connect.arrow.test_parity_arrow import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore[import]
@@ -17,7 +17,7 @@

 import unittest

-from pyspark.sql.tests.test_arrow_cogrouped_map import CogroupedMapInArrowTestsMixin
+from pyspark.sql.tests.arrow.test_arrow_cogrouped_map import CogroupedMapInArrowTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase


@@ -26,7 +26,7 @@ class CogroupedMapInArrowParityTests(CogroupedMapInArrowTestsMixin, ReusedConnec


 if __name__ == "__main__":
-    from pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map import *  # noqa: F401
+    from pyspark.sql.tests.connect.arrow.test_parity_arrow_cogrouped_map import *  # noqa: F401

     try:
         import xmlrunner
@@ -17,7 +17,7 @@

 import unittest

-from pyspark.sql.tests.test_arrow_grouped_map import GroupedMapInArrowTestsMixin
+from pyspark.sql.tests.arrow.test_arrow_grouped_map import GroupedMapInArrowTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase


@@ -26,7 +26,7 @@ class GroupedApplyInArrowParityTests(GroupedMapInArrowTestsMixin, ReusedConnectT


 if __name__ == "__main__":
-    from pyspark.sql.tests.connect.test_parity_arrow_grouped_map import *  # noqa: F401
+    from pyspark.sql.tests.connect.arrow.test_parity_arrow_grouped_map import *  # noqa: F401

     try:
         import xmlrunner
@@ -17,7 +17,7 @@

 import unittest

-from pyspark.sql.tests.test_arrow_map import MapInArrowTestsMixin
+from pyspark.sql.tests.arrow.test_arrow_map import MapInArrowTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase


@@ -26,7 +26,7 @@ class ArrowMapParityTests(MapInArrowTestsMixin, ReusedConnectTestCase):


 if __name__ == "__main__":
-    from pyspark.sql.tests.connect.test_parity_arrow_map import *  # noqa: F401
+    from pyspark.sql.tests.connect.arrow.test_parity_arrow_map import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore[import]
@@ -16,7 +16,7 @@
 #

 from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
-from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin
+from pyspark.sql.tests.arrow.test_arrow_python_udf import PythonUDFArrowTestsMixin


 class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
@@ -35,7 +35,7 @@ def tearDownClass(cls):

 if __name__ == "__main__":
     import unittest
-    from pyspark.sql.tests.connect.test_parity_arrow_python_udf import *  # noqa: F401
+    from pyspark.sql.tests.connect.arrow.test_parity_arrow_python_udf import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore[import]
