[SPARK-20040][ML][python] pyspark wrapper for ChiSquareTest #17421
Changes from 4 commits: 1c6acd7, 9f177c6, aa40d58, 3e7163c, e79f968
@@ -0,0 +1,104 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import since, SparkContext
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.wrapper import _jvm
class ChiSquareTest(object):

[Review comment: Mark as Experimental (search for other examples of this).]
[Review comment: Also, we put the triple-quotes on their own line elsewhere in pyspark.]
""" | ||
.. note:: Experimental | ||
|
||
Conduct Pearson's independence test for every feature against the label. For each feature, | ||
the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared | ||
statistic is computed. All label and feature values must be categorical. | ||
|
||
The null hypothesis is that the occurrence of the outcomes is statistically independent. | ||
|
||
    :param dataset:

[Review comment: Copy param text from the Scala doc, unless there's a need to customize it for Python.]
[Review comment: Same for the return value text.]
      DataFrame of categorical labels and categorical features.
      Real-valued features will be treated as categorical for each distinct value.
    :param featuresCol:
      Name of features column in dataset, of type `Vector` (`VectorUDT`).
    :param labelCol:
      Name of label column in dataset, of any numerical type.
    :return:
      DataFrame containing the test result for every feature against the label.
      This DataFrame will contain a single Row with the following fields:
      - `pValues: Vector`
      - `degreesOfFreedom: Array[Int]`
      - `statistics: Vector`
      Each of these fields has one value per feature.
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.stat import ChiSquareTest
    >>> dataset = [[0, Vectors.dense([0, 0, 1])],
    ...            [0, Vectors.dense([1, 0, 1])],
    ...            [1, Vectors.dense([2, 1, 1])],
    ...            [1, Vectors.dense([3, 1, 1])]]
    >>> dataset = spark.createDataFrame(dataset, ["label", "features"])
    >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
    >>> chiSqResult.select("degreesOfFreedom").collect()[0]
    Row(degreesOfFreedom=[3, 1, 0])

    .. versionadded:: 2.2.0
    """
    @staticmethod
    @since("2.2.0")
    def test(dataset, featuresCol, labelCol):
        """
        Perform a Pearson's independence test using dataset.
        """
        sc = SparkContext._active_spark_context
        javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest
        args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)]
        return _java2py(sc, javaTestObj.test(*args))

if __name__ == "__main__":
    import doctest
    import pyspark.ml.stat
    from pyspark.sql import SparkSession

    globs = pyspark.ml.stat.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("ml.stat tests") \
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    import tempfile

    temp_path = tempfile.mkdtemp()

[Review comment: I don't think this test is using the temp path?]

    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
        spark.stop()
    finally:
        from shutil import rmtree

        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        exit(-1)
@@ -41,9 +41,7 @@
 import tempfile
 import array as pyarray
 import numpy as np
-from numpy import (
-    abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros)
-from numpy import sum as array_sum
+from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros

[Review comment: Thanks for cleaning up the numpy imports :) +1]

 import inspect

 from pyspark import keyword_only, SparkContext
@@ -54,20 +52,19 @@
 from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.fpm import FPGrowth, FPGrowthModel
-from pyspark.ml.linalg import (
-    DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT,
-    SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector)
+from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \
+    SparseMatrix, SparseVector, Vector, VectorUDT, Vectors
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
 from pyspark.ml.recommendation import ALS
-from pyspark.ml.regression import (
-    DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression)
+from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \
+    LinearRegression
+from pyspark.ml.stat import ChiSquareTest
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams, JavaWrapper
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
 from pyspark.sql.utils import IllegalArgumentException
 from pyspark.storagelevel import *
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -1741,6 +1738,22 @@ def test_new_java_array(self):
         self.assertEqual(_java2py(self.sc, java_array), [])


class ChiSquareTestTests(SparkSessionTestCase):

    def test_chisquaretest(self):
        data = [[0, Vectors.dense([0, 1, 2])],
                [1, Vectors.dense([1, 1, 1])],
                [2, Vectors.dense([2, 1, 0])]]
        df = self.spark.createDataFrame(data, ['label', 'feat'])
        res = ChiSquareTest.test(df, 'feat', 'label')
        # This line is hitting the collect bug described in #17218, commented for now.
        # pValues = res.select("degreesOfFreedom").collect()
        self.assertIsInstance(res, DataFrame)
        fieldNames = set(field.name for field in res.schema.fields)
        expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
        self.assertTrue(all(field in fieldNames for field in expectedFields))


if __name__ == "__main__":
    from pyspark.ml.tests import *
    if xmlrunner:
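The test above asserts only on the result's schema, since the `#17218` collect bug blocks value checks. A standalone sketch of the shape being asserted on, with entirely hypothetical values (one entry per feature in each field, as the class docstring describes):

```python
# Hypothetical single-row result, mirroring the fields the test checks.
# The numeric values here are made up for illustration only.
result_row = {
    "pValues": [0.26, 0.05, 1.0],
    "degreesOfFreedom": [3, 1, 0],
    "statistics": [4.0, 3.8, 0.0],
}

# The same order-insensitive membership check used in test_chisquaretest:
fieldNames = set(result_row)
expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
assert all(field in fieldNames for field in expectedFields)
```

Checking membership in a set of field names rather than comparing a fixed list keeps the test robust if Spark ever adds or reorders result columns.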
[Review thread on setup.py:]

We just took it out in 314cf51, but since this is adding back in ml.stat we also need to update setup.py (you might need to update your branch from the latest master to see this).

@holdenk thanks for catching that, should be fixed now.

Wait, do we need to update setup.py? This is creating a module, not a package, right?

Sub-modules aren't automatically packaged, so we do need to explicitly add it.

Thanks @jkbradley, I reverted setup.py.

@holdenk If we need to add pyspark.ml.stat to setup.py, then why are we not adding the other analogous modules: pyspark.ml.{classification, clustering, regression, ...}?

Oh yeah, sorry: it's anything which is a new sub-directory, and when I was reading this PR yesterday I thought this was a new directory. Looking at it today, that isn't the case, sorry.

OK, no problem, I just wanted to check.
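The distinction the thread settles on can be sketched as follows. This is a hypothetical, simplified excerpt (not Spark's actual setup.py): setuptools' `packages` argument lists package directories, so a new module file such as `pyspark/ml/stat.py` inside an already-listed package needs no new entry, whereas a new sub-directory (sub-package) must be listed explicitly.

```python
# Hypothetical excerpt of a setuptools `packages` list for illustration.
packages = [
    "pyspark",
    "pyspark.ml",         # stat.py is a module inside this existing package
    "pyspark.ml.linalg",  # a sub-directory (sub-package) is listed explicitly
    "pyspark.ml.param",
]

# A module is not a package, so it never appears in this list:
assert "pyspark.ml.stat" not in packages
assert "pyspark.ml" in packages
```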