[SPARK-17315][SparkR] Kolmogorov-Smirnov test SparkR wrapper #14881

Closed · wants to merge 4 commits
7 changes: 5 additions & 2 deletions R/pkg/NAMESPACE
@@ -35,7 +35,8 @@ exportMethods("glm",
"spark.perplexity",
"spark.isoreg",
"spark.gaussianMixture",
"spark.als")
"spark.als",
"spark.kstest")

# Job group lifecycle management methods
export("setJobGroup",
@@ -335,7 +336,8 @@ export("as.DataFrame",
"tables",
"uncacheTable",
"print.summary.GeneralizedLinearRegressionModel",
"read.ml")
"read.ml",
"print.summary.KSTest")

export("structField",
"structField.jobj",
@@ -359,6 +361,7 @@ S3method(print, jobj)
S3method(print, structField)
S3method(print, structType)
S3method(print, summary.GeneralizedLinearRegressionModel)
S3method(print, summary.KSTest)
S3method(structField, character)
S3method(structField, jobj)
S3method(structType, jobj)
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -1375,3 +1375,7 @@ setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml")
#' @rdname spark.als
#' @export
setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })

#' @rdname spark.kstest
#' @export
setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") })
108 changes: 108 additions & 0 deletions R/pkg/R/mllib.R
@@ -88,6 +88,13 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
#' @note ALSModel since 2.1.0
setClass("ALSModel", representation(jobj = "jobj"))

#' S4 class that represents a KSTest
#'
#' @param jobj a Java object reference to the backing Scala KSTestWrapper
#' @export
#' @note KSTest since 2.1.0
setClass("KSTest", representation(jobj = "jobj"))

#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -1308,3 +1315,104 @@ setMethod("write.ml", signature(object = "ALSModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

#' (One-Sample) Kolmogorov-Smirnov Test
#'
#' @description
#' \code{spark.kstest} conducts the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
#' continuous distribution.
#'
#' By comparing the largest difference between the empirical cumulative
#' distribution of the sample data and the theoretical distribution, we can provide a test for
#' the null hypothesis that the sample data comes from that theoretical distribution.
#'
#' Users can call \code{summary} to obtain a summary of the test, and \code{print.summary.KSTest}
#' to print out a summary result.
#'
#' @details
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{
#' MLlib: Hypothesis Testing}.
Member: maybe put this in @seealso? That seems to be the typical way to add a link in our doc
#'
#' @param data a SparkDataFrame of user data.
#' @param testCol name of the column that contains the test data. It should be a column of double type.
#' @param nullHypothesis name of the theoretical distribution tested against. Currently only
#' \code{"norm"} for normal distribution is supported.
#' @param distParams parameter(s) of the distribution. For \code{nullHypothesis = "norm"},
#' we can provide as a vector the mean and standard deviation of
#' the distribution. If none is provided, then the standard normal will be used.
#' If only one is provided, then the standard deviation will be set to one.
#' @param ... additional argument(s) passed to the method.
#' @return \code{spark.kstest} returns a test result object.
#' @rdname spark.kstest
#' @aliases spark.kstest,SparkDataFrame-method
#' @name spark.kstest
#' @export
#' @examples
#' \dontrun{
#' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25))
#' df <- createDataFrame(data)
#' test <- spark.kstest(df, "test", "norm", c(0, 1))
#'
#' # get a summary of the test result
#' testSummary <- summary(test)
#' testSummary
#'
#' # print out the summary in an organized way
#' print.summary.KSTest(test)
#' }
#' @note spark.kstest since 2.1.0
setMethod("spark.kstest", signature(data = "SparkDataFrame"),
function(data, testCol = "test", nullHypothesis = c("norm"), distParams = c(0, 1)) {
tryCatch(match.arg(nullHypothesis),
error = function(e) {
msg <- paste("Distribution", nullHypothesis, "is not supported.")
stop(msg)
})
if (nullHypothesis == "norm") {
felixcheung (Member, Aug 31, 2016): should it stop if nullHypothesis is not norm?
ah - it's already there above.

Contributor Author: I did this intentionally in case we add more distributions in the future.

distParams <- as.numeric(distParams)
mu <- ifelse(length(distParams) < 1, 0, distParams[1])
sigma <- ifelse(length(distParams) < 2, 1, distParams[2])
jobj <- callJStatic("org.apache.spark.ml.r.KSTestWrapper",
"test", data@sdf, testCol, nullHypothesis,
as.array(c(mu, sigma)))
new("KSTest", jobj = jobj)
}
})
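The statistic behind this wrapper is language-agnostic, so it can be illustrated without Spark. The sketch below is a minimal, hypothetical Python rendering (standard library only; the function names `norm_cdf` and `ks_statistic` are this sketch's own, not Spark API) of the two-sided one-sample KS statistic against a normal null, mirroring the mu/sigma defaults used above:

```python
import math

def norm_cdf(x, mu=0.0, sigma=1.0):
    # Normal CDF via the error function.
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def ks_statistic(sample, mu=0.0, sigma=1.0):
    """Two-sided one-sample KS statistic D = sup_x |F_n(x) - F(x)|.

    For sorted data x_(1), ..., x_(n), the supremum is attained at a sample
    point, so D = max_i max(i/n - F(x_(i)), F(x_(i)) - (i - 1)/n).
    """
    xs = sorted(sample)
    n = len(xs)
    d = 0.0
    for i, x in enumerate(xs, start=1):
        cdf = norm_cdf(x, mu, sigma)
        d = max(d, i / n - cdf, cdf - (i - 1) / n)
    return d

sample = [0.1, 0.15, 0.2, 0.3, 0.25]
print(round(ks_statistic(sample), 4))  # prints 0.5398
```

With all five points above 0, the empirical CDF jumps well ahead of the standard normal CDF near 0.1, which is where the largest gap (about 0.54) occurs.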

# Get the summary of Kolmogorov-Smirnov (KS) Test.
#' @param object test result object of KS.
Member: seem like we usually call out \code{spark.kstest} - I think because summary rd is documenting a bunch of functions

Member: The earlier comment was about this line in summary actually.

Contributor Author: It seems the summary method is in the spark.kstest doc? summary rd only includes methods for SparkDataFrame.

Member: It is, it's your call - I thought it would be better to be consistent with all other summary methods but it wasn't clear why that was done initially.
#' @return \code{summary} returns a list containing the p-value, the test statistic computed for
#'         the test, the null hypothesis and its parameters tested against,
#'         and the degrees of freedom of the test.
#' @rdname spark.kstest
#' @aliases summary,KSTest-method
#' @export
#' @note summary(KSTest) since 2.1.0
setMethod("summary", signature(object = "KSTest"),
function(object) {
jobj <- object@jobj
pValue <- callJMethod(jobj, "pValue")
statistic <- callJMethod(jobj, "statistic")
nullHypothesis <- callJMethod(jobj, "nullHypothesis")
distName <- callJMethod(jobj, "distName")
distParams <- unlist(callJMethod(jobj, "distParams"))
degreesOfFreedom <- callJMethod(jobj, "degreesOfFreedom")

list(p.value = pValue, statistic = statistic, nullHypothesis = nullHypothesis,
nullHypothesis.name = distName, nullHypothesis.parameters = distParams,
degreesOfFreedom = degreesOfFreedom)
})
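For intuition about the p-value returned in this list, the two-sided tail probability can be approximated for moderate sample sizes by the asymptotic Kolmogorov distribution. The sketch below is an illustration only (the name `ks_pvalue_asymptotic` is hypothetical, and exact small-sample p-values such as those from R's ks.test will differ somewhat from this large-sample formula):

```python
import math

def ks_pvalue_asymptotic(d, n, terms=100):
    """Asymptotic two-sided p-value for the one-sample KS test:
    P(D_n > d) ~ 2 * sum_{k>=1} (-1)^(k-1) * exp(-2 * k^2 * n * d^2).
    """
    t = 2.0 * n * d * d
    s = sum((-1) ** (k - 1) * math.exp(-(k ** 2) * t)
            for k in range(1, terms + 1))
    # Clamp to [0, 1] to guard against truncation artifacts.
    return max(0.0, min(1.0, 2.0 * s))

# For an illustrative statistic of D = 0.5398 on n = 5 points,
# the approximation gives roughly 0.11.
p = ks_pvalue_asymptotic(0.5398, 5)
print(p)
```

Because the series converges very fast (the k = 2 term is already negligible here), a handful of terms suffices in practice.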

# Prints the summary of GeneralizedLinearRegressionModel
Member: -> KSTest


#' @rdname spark.kstest
#' @param x test result object of KS.
#' @export
#' @note print.summary.KSTest since 2.1.0
print.summary.KSTest <- function(x, ...) {
jobj <- x@jobj
summaryStr <- callJMethod(jobj, "summary")
cat(summaryStr)
invisible(summaryStr)
}
20 changes: 20 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
@@ -736,4 +736,24 @@ test_that("spark.als", {
unlink(modelPath)
})

test_that("spark.kstest", {
data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5))
df <- createDataFrame(data)
testResult <- spark.kstest(df, "test", "norm")
stats <- summary(testResult)

rStats <- ks.test(data$test, "pnorm", alternative = "two.sided")

expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4)
expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4)

testResult <- spark.kstest(df, "test", "norm", -0.5)
stats <- summary(testResult)
Member: could you add a test for print.summary too?


rStats <- ks.test(data$test, "pnorm", -0.5, 1, alternative = "two.sided")

expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4)
expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4)
})

sparkR.session.stop()
57 changes: 57 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/r/KSTestWrapper.scala
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.r

import org.apache.spark.mllib.stat.Statistics.kolmogorovSmirnovTest
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult
import org.apache.spark.sql.{DataFrame, Row}

private[r] class KSTestWrapper private (
val testResult: KolmogorovSmirnovTestResult,
val distName: String,
val distParams: Array[Double]) {

lazy val pValue = testResult.pValue

lazy val statistic = testResult.statistic

lazy val nullHypothesis = testResult.nullHypothesis

lazy val degreesOfFreedom = testResult.degreesOfFreedom

def summary: String = testResult.toString
}

private[r] object KSTestWrapper {

def test(
data: DataFrame,
featureName: String,
distName: String,
distParams: Array[Double]): KSTestWrapper = {

val rddData = data.select(featureName).rdd.map {
case Row(feature: Double) => feature
}

val ksTestResult = kolmogorovSmirnovTest(rddData, distName, distParams : _*)

new KSTestWrapper(ksTestResult, distName, distParams)
}
}