Skip to content

Commit

Permalink
[SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRe…
Browse files Browse the repository at this point in the history
…gression in SparkR

## What changes were proposed in this pull request?
This PR continues the work in #11447, we implemented the wrapper of ```AFTSurvivalRegression``` named ```survreg``` in SparkR.

## How was this patch tested?
Test against output from R package survival's survreg.

cc mengxr felixcheung

Close #11447

Author: Yanbo Liang <[email protected]>

Closes #11932 from yanboliang/spark-13010-new.
  • Loading branch information
yanboliang authored and mengxr committed Mar 25, 2016
1 parent 05f652d commit 13cbb2d
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 2 deletions.
3 changes: 2 additions & 1 deletion R/pkg/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ Depends:
methods,
Suggests:
testthat,
e1071
e1071,
survival
Description: R frontend for Spark
License: Apache License (== 2.0)
Collate:
Expand Down
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ exportMethods("glm",
"summary",
"kmeans",
"fitted",
"naiveBayes")
"naiveBayes",
"survreg")

# Job group lifecycle management methods
export("setJobGroup",
Expand Down
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1179,3 +1179,7 @@ setGeneric("fitted")
#' @rdname naiveBayes
#' @export
setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })

#' @rdname survreg
#' @export
setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
75 changes: 75 additions & 0 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @export
setClass("NaiveBayesModel", representation(jobj = "jobj"))

#' @title S4 class that represents a AFTSurvivalRegressionModel
#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
#' @export
setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))

#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
Expand Down Expand Up @@ -273,3 +278,73 @@ setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
formula, data@sdf, laplace)
return(new("NaiveBayesModel", jobj = jobj))
})

#' Fit an accelerated failure time (AFT) survival regression model.
#'
#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' Note that operator '.' is not supported currently.
#' @param data DataFrame for training.
#' @return a fitted AFT survival regression model
#' @rdname survreg
#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(sqlContext, ovarian)
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
#' }
setMethod("survreg", signature(formula = "formula", data = "DataFrame"),
function(formula, data, ...) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
"fit", formula, data@sdf)
return(new("AFTSurvivalRegressionModel", jobj = jobj))
})

#' Get the summary of an AFT survival regression model
#'
#' Returns the summary of an AFT survival regression model produced by survreg(),
#' similarly to R's summary().
#'
#' @param object a fitted AFT survival regression model
#' @return coefficients the model's coefficients, intercept and log(scale).
#' @rdname summary
#' @export
#' @examples
#' \dontrun{
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
#' summary(model)
#' }
setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
function(object, ...) {
jobj <- object@jobj
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Value")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
})

#' Make predictions from an AFT survival regression model
#'
#' Make predictions from a model produced by survreg(), similarly to R package survival's predict.
#'
#' @param object A fitted AFT survival regression model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted labels in a column named "prediction"
#' @rdname predict
#' @export
#' @examples
#' \dontrun{
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#' }
setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
49 changes: 49 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,52 @@ test_that("naiveBayes", {
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
})

test_that("survreg", {
# R code to reproduce the result.
#
#' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
#' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
#' library(survival)
#' model <- survreg(Surv(time, status) ~ x + sex, rData)
#' summary(model)
#' predict(model, data)
#
# -- output of 'summary(model)'
#
# Value Std. Error z p
# (Intercept) 1.315 0.270 4.88 1.07e-06
# x -0.190 0.173 -1.10 2.72e-01
# sex -0.253 0.329 -0.77 4.42e-01
# Log(scale) -1.160 0.396 -2.93 3.41e-03
#
# -- output of 'predict(model, data)'
#
# 1 2 3 4 5 6 7
# 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269
#
data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
model <- survreg(Surv(time, status) ~ x + sex, df)
stats <- summary(model)
coefs <- as.vector(stats$coefficients[, 1])
rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)
expect_equal(coefs, rCoefs, tolerance = 1e-4)
expect_true(all(
rownames(stats$coefficients) ==
c("(Intercept)", "x", "sex", "Log(scale)")))
p <- collect(select(predict(model, df), "prediction"))
expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
2.390146, 2.891269, 2.891269), tolerance = 1e-4)

# Test survival::survreg
if (requireNamespace("survival", quietly = TRUE)) {
rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
expect_that(
model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
not(throws_error()))
expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
}
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.r

import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
import org.apache.spark.sql.DataFrame

private[r] class AFTSurvivalRegressionWrapper private (
pipeline: PipelineModel,
features: Array[String]) {

private val aftModel: AFTSurvivalRegressionModel =
pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]

lazy val rCoefficients: Array[Double] = if (aftModel.getFitIntercept) {
Array(aftModel.intercept) ++ aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale))
} else {
aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale))
}

lazy val rFeatures: Array[String] = if (aftModel.getFitIntercept) {
Array("(Intercept)") ++ features ++ Array("Log(scale)")
} else {
features ++ Array("Log(scale)")
}

def transform(dataset: DataFrame): DataFrame = {
pipeline.transform(dataset)
}
}

private[r] object AFTSurvivalRegressionWrapper {

private def formulaRewrite(formula: String): (String, String) = {
var rewritedFormula: String = null
var censorCol: String = null

val regex = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r
try {
val regex(label, censor, features) = formula
// TODO: Support dot operator.
if (features.contains(".")) {
throw new UnsupportedOperationException(
"Terms of survreg formula can not support dot operator.")
}
rewritedFormula = label.trim + "~" + features.trim
censorCol = censor.trim
} catch {
case e: MatchError =>
throw new SparkException(s"Could not parse formula: $formula")
}

(rewritedFormula, censorCol)
}


def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {

val (rewritedFormula, censorCol) = formulaRewrite(formula)

val rFormula = new RFormula().setFormula(rewritedFormula)
val rFormulaModel = rFormula.fit(data)

// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
.attributes.get
val features = featureAttrs.map(_.name.get)

val aft = new AFTSurvivalRegression()
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
.fit(data)

new AFTSurvivalRegressionWrapper(pipeline, features)
}
}

0 comments on commit 13cbb2d

Please sign in to comment.