Skip to content

Commit

Permalink
[SPARK-50942][ML][PYTHON][CONNECT] Support ChiSquareTest on Connect
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support `ChiSquareTest` on Connect

### Why are the changes needed?
Feature parity with non-Connect PySpark.

### Does this PR introduce _any_ user-facing change?
Yes. `ChiSquareTest.test` can now be called from a Spark Connect session.

### How was this patch tested?
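Added tests: a new Connect parity suite (`test_parity_stat`) and additional assertions in `test_chisquaretest`.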

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49771 from zhengruifeng/ml_connect_chi_square_test.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 3, 2025
1 parent bff86f1 commit d29648d
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 11 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_parity_pipeline",
"pyspark.ml.tests.connect.test_parity_tuning",
"pyspark.ml.tests.connect.test_parity_ovr",
"pyspark.ml.tests.connect.test_parity_stat",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ org.apache.spark.ml.recommendation.ALSModel
org.apache.spark.ml.fpm.FPGrowthModel
org.apache.spark.ml.fpm.PrefixSpanWrapper

# stat
org.apache.spark.ml.stat.ChiSquareTestWrapper

# feature
org.apache.spark.ml.feature.RFormulaModel
org.apache.spark.ml.feature.ImputerModel
Expand Down
45 changes: 43 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
package org.apache.spark.ml.stat

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.test.{ChiSqTest => OldChiSqTest}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


/**
Expand Down Expand Up @@ -101,3 +105,40 @@ object ChiSquareTest {
}
}
}

/**
 * Adapter that exposes [[ChiSquareTest]] through the [[Transformer]] API.
 *
 * [[ChiSquareTest]] is neither an Estimator nor a Transformer, so it has to be wrapped to be
 * usable from Spark Connect. `transform` forwards to `ChiSquareTest.test` with the configured
 * `featuresCol`, `labelCol` and `flatten` params.
 */
private[spark] class ChiSquareTestWrapper(override val uid: String)
  extends Transformer with HasFeaturesCol with HasLabelCol {

  def this() = this(Identifiable.randomUID("ChiSquareTestWrapper"))

  /** Output layout switch: a single summary row (false, the default) or one row per feature. */
  val flatten: BooleanParam = new BooleanParam(this, "flatten",
    "If false, the returned DataFrame contains only a single Row, otherwise, one row per feature.")

  setDefault(flatten -> false)

  /**
   * Describes the result columns. The incoming schema is not consulted here; only the
   * `flatten` param decides which of the two layouts is returned.
   */
  override def transformSchema(schema: StructType): StructType = {
    val fields: Seq[StructField] =
      if ($(flatten)) {
        // One row per feature.
        Seq(
          StructField("featureIndex", IntegerType, nullable = false),
          StructField("pValue", DoubleType, nullable = false),
          StructField("degreesOfFreedom", IntegerType, nullable = false),
          StructField("statistic", DoubleType, nullable = false))
      } else {
        // A single row holding per-feature results as vectors / an array.
        Seq(
          StructField("pValues", new VectorUDT, nullable = false),
          StructField(
            "degreesOfFreedom",
            ArrayType(IntegerType, containsNull = false),
            nullable = false),
          StructField("statistics", new VectorUDT, nullable = false))
      }
    StructType(fields)
  }

  /** Runs `ChiSquareTest.test` on `dataset` using the current param values. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val input = dataset.toDF()
    val features = $(featuresCol)
    val label = $(labelCol)
    ChiSquareTest.test(input, features, label, $(flatten))
  }

  override def copy(extra: ParamMap): ChiSquareTestWrapper = defaultCopy(extra)
}

9 changes: 9 additions & 0 deletions python/pyspark/ml/connect/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,12 @@ def serialize_ml_params(instance: "Params", client: "SparkConnectClient") -> pb2
k.name: serialize_param(v, client) for k, v in instance._paramMap.items()
}
return pb2.MlParams(params=params)


def serialize_ml_params_values(
    values: Dict[str, Any], client: "SparkConnectClient"
) -> pb2.MlParams:
    """Build a ``pb2.MlParams`` message from a plain name-to-value dict.

    Each value in ``values`` is passed through ``serialize_param`` together
    with ``client``; the keys are used as the param names as-is.
    """
    serialized: Dict[str, pb2.Expression.Literal] = {}
    for name, value in values.items():
        serialized[name] = serialize_param(value, client)
    return pb2.MlParams(params=serialized)
28 changes: 22 additions & 6 deletions python/pyspark/ml/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,30 @@ def test(
>>> row[0].statistic
4.0
"""
from pyspark.core.context import SparkContext
from pyspark.sql.utils import is_remote

sc = SparkContext._active_spark_context
assert sc is not None
if is_remote():
from pyspark.ml.wrapper import JavaTransformer
from pyspark.ml.connect.serialize import serialize_ml_params_values

instance = JavaTransformer()
instance._java_obj = "org.apache.spark.ml.stat.ChiSquareTestWrapper"
serialized_ml_params = serialize_ml_params_values(
{"featuresCol": featuresCol, "labelCol": labelCol, "flatten": flatten},
dataset.sparkSession.client, # type: ignore[arg-type,operator]
)
instance._serialized_ml_params = serialized_ml_params # type: ignore[attr-defined]
return instance.transform(dataset)

else:
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None

javaTestObj = getattr(_jvm(), "org.apache.spark.ml.stat.ChiSquareTest")
args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol, flatten)]
return _java2py(sc, javaTestObj.test(*args))
javaTestObj = getattr(_jvm(), "org.apache.spark.ml.stat.ChiSquareTest")
args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol, flatten)]
return _java2py(sc, javaTestObj.test(*args))


class Correlation:
Expand Down
37 changes: 37 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.ml.tests.test_stat import StatTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class StatParityTests(StatTestsMixin, ReusedConnectTestCase):
    """Runs the tests defined in ``StatTestsMixin`` on top of ``ReusedConnectTestCase``."""

    pass


if __name__ == "__main__":
    from pyspark.ml.tests.connect.test_parity_stat import *  # noqa: F401

    # Use the XML-reporting runner when xmlrunner can be imported; otherwise
    # leave the runner as None and pass that to unittest.main.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
18 changes: 15 additions & 3 deletions python/pyspark/ml/tests/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import ChiSquareTest
from pyspark.sql import DataFrame
from pyspark.testing.mlutils import SparkSessionTestCase
from pyspark.testing.sqlutils import ReusedSQLTestCase


class ChiSquareTestTests(SparkSessionTestCase):
class StatTestsMixin:
def test_chisquaretest(self):
spark = self.spark
data = [
[0, Vectors.dense([0, 1, 2])],
[1, Vectors.dense([1, 1, 1])],
[2, Vectors.dense([2, 1, 0])],
]
df = self.spark.createDataFrame(data, ["label", "feat"])
df = spark.createDataFrame(data, ["label", "feat"])
res = ChiSquareTest.test(df, "feat", "label")
# This line is hitting the collect bug described in #17218, commented for now.
# pValues = res.select("degreesOfFreedom").collect())
Expand All @@ -39,6 +40,17 @@ def test_chisquaretest(self):
expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
self.assertTrue(all(field in fieldNames for field in expectedFields))

self.assertEqual(res.columns, ["pValues", "degreesOfFreedom", "statistics"])
self.assertEqual(res.count(), 1)

res2 = ChiSquareTest.test(df, "feat", "label", True)
self.assertEqual(res2.columns, ["featureIndex", "pValue", "degreesOfFreedom", "statistic"])
self.assertEqual(res2.count(), 3)


class StatTests(StatTestsMixin, ReusedSQLTestCase):
    """Runs the tests defined in ``StatTestsMixin`` on top of ``ReusedSQLTestCase``."""

    pass


if __name__ == "__main__":
from pyspark.ml.tests.test_stat import * # noqa: F401
Expand Down

0 comments on commit d29648d

Please sign in to comment.