-
-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add qgam learner * Tests for qgam learner * small corrections * fix style issues * more style fixes * Apply changes from code review * check that features in form and column_roles are the same
- Loading branch information
Showing
8 changed files
with
332 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,8 @@ Authors@R: c( | |
person("John", "Zobolas", ,"[email protected]", role = "ctb", | ||
comment = c(ORCID = "0000-0002-3609-8674")), | ||
person("Lukas", "Burk", , "[email protected]", role = "ctb", | ||
comment = c(ORCID = "0000-0001-7528-3795")) | ||
comment = c(ORCID = "0000-0001-7528-3795")), | ||
person("Lona", "Koers", , "[email protected]", role = "ctb") | ||
) | ||
Description: Extra learners for use in mlr3. | ||
License: LGPL-3 | ||
|
@@ -85,6 +86,7 @@ Suggests: | |
pracma, | ||
prioritylasso (>= 0.3.1), | ||
pseudo, | ||
qgam, | ||
randomForest, | ||
randomPlantedForest, | ||
randomForestSRC (>= 3.3.0), | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
#' @title Regression Quantile Generalized Additive Model Learner | ||
#' @author lona-k | ||
#' @name mlr_learners_regr.qgam | ||
#' | ||
#' @description | ||
#' Quantile Regression with generalized additive models. | ||
#' Calls [qgam::qgam()] from package \CRANpkg{qgam}. | ||
#' | ||
#' @section Form: | ||
#' For the `form` parameter, a gam formula specific to the [Task][mlr3::Task] is required (see example and `?mgcv::formula.gam`). | ||
#' If no formula is provided, a fallback formula using all features in the task is used that will make the Learner behave like Linear Quantile Regression. | ||
#' The features specified in the formula need to be the same as columns with col_roles "feature" in the task. | ||
#' | ||
#' @section Quantile: | ||
#' The quantile for the Learner, i.e. `qu` parameter from [qgam::qgam()], is set using the value specified in `learner$quantiles`. | ||
#' | ||
#' @templateVar id regr.qgam | ||
#' @template learner | ||
#' | ||
#' @references | ||
#' `r format_bib("faisolo2017qgam")` | ||
#' | ||
#' @template seealso_learner | ||
#' @examplesIf requireNamespace("qgam", quietly = TRUE) | ||
#' # simple example | ||
# t = mlr3::tsk("mtcars") | ||
# l = mlr3::lrn("regr.qgam") | ||
# t$select(c("cyl", "am", "disp", "hp")) | ||
# l$param_set$values$form = mpg ~ cyl + am + s(disp) + s(hp) | ||
# l$quantiles = 0.25 | ||
# l$train(t) | ||
# l$model | ||
# l$predict(t) | ||
#' @export | ||
LearnerRegrQGam = R6Class("LearnerRegrQGam", | ||
inherit = LearnerRegr, | ||
public = list( | ||
#' @description | ||
#' Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function() { | ||
param_set = ps( | ||
form = p_uty(tags = "train"), | ||
lsig = p_dbl(tags = "train"), | ||
err = p_dbl(lower = 0, upper = 1, tags = "train"), | ||
cluster = p_uty(default = NULL, tags = "train"), | ||
multicore = p_lgl(tags = "train"), | ||
ncores = p_dbl(tags = "train"), | ||
paropts = p_uty(default = list(), tags = "train"), | ||
link = p_uty(default = "identity", tags = "train"), | ||
argGam = p_uty(custom_check = crate(function(x) { | ||
checkmate::check_list(x, names = "unique", null.ok = TRUE) | ||
}), tags = "train"), | ||
block.size = p_int(default = 1000L, tags = "predict"), | ||
unconditional = p_lgl(default = FALSE, tags = "predict") | ||
) | ||
|
||
super$initialize( | ||
id = "regr.qgam", | ||
packages = "qgam", | ||
feature_types = c("logical", "integer", "numeric", "factor"), | ||
predict_types = c("response", "se", "quantiles"), | ||
param_set = param_set, | ||
properties = "weights", | ||
man = "mlr3extralearners::mlr_learners_regr.qgam", | ||
label = "Regression Quantile Generalized Additive Model Learner" | ||
) | ||
|
||
self$predict_type = "quantiles" | ||
self$quantiles = 0.5 | ||
} | ||
), | ||
private = list( | ||
.train = function(task) { | ||
data = task$data(cols = c(task$feature_names, task$target_names)) | ||
|
||
# get parameters for training | ||
pars = self$param_set$get_values(tags = "train") | ||
control_pars = if (length(pars$link)) list(pars$link) else list(NULL) | ||
|
||
args_gam = formalArgs(mgcv::gam)[formalArgs(mgcv::gam) %nin% c("formula", "family", "data")] | ||
if (length(pars$argGam)) { | ||
checkmate::assert_subset(names(pars$argGam), choices = args_gam, empty.ok = FALSE) | ||
} | ||
|
||
arg_gam_pars = pars$argGam | ||
pars = pars[names(pars) %nin% c("argGam", "link")] | ||
|
||
if ("weights" %in% task$properties) { | ||
arg_gam_pars = insert_named(arg_gam_pars, list(weights = task$weights$weight)) | ||
} | ||
|
||
if (is.null(pars$form)) { | ||
form = stats::reformulate(task$feature_names, response = task$target_names) | ||
pars$form = form | ||
} | ||
|
||
checkmate::assert_set_equal(all.vars(pars$form)[-1], task$col_roles$feature) | ||
checkmate::assert_set_equal(all.vars(pars$form)[[1]], task$col_roles$target) | ||
invoke( | ||
qgam::qgam, | ||
qu = self$quantiles, | ||
data = data, | ||
.args = pars, | ||
control = control_pars, | ||
argGam = arg_gam_pars | ||
) | ||
}, | ||
.predict = function(task) { # qgam uses predict.gam | ||
# get parameters with tag "predict" | ||
pars = self$param_set$get_values(tags = "predict") | ||
|
||
# get newdata and ensure same ordering in train and predict | ||
newdata = ordered_features(task, self) | ||
|
||
include_se = (self$predict_type == "se") | ||
|
||
preds = invoke( | ||
predict, | ||
self$model, | ||
newdata = newdata, | ||
type = "response", | ||
newdata.guaranteed = TRUE, | ||
se.fit = include_se, | ||
.args = pars | ||
) | ||
|
||
if (include_se) { # se and response | ||
list(response = preds$fit, se = preds$se) | ||
} else if (self$predict_type == "quantiles") { | ||
quantiles = matrix(preds, ncol = 1) | ||
attr(quantiles, "probs") = self$quantiles | ||
attr(quantiles, "response") = self$quantiles | ||
list(quantiles = quantiles) | ||
} else { | ||
list(response = preds) | ||
} | ||
} | ||
) | ||
) | ||
|
||
.extralrns_dict$add("regr.qgam", LearnerRegrQGam) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
test_that("regr.qgam train", { | ||
learner = lrn("regr.qgam") | ||
fun = qgam::qgam | ||
exclude = c( | ||
"data", # handled internally | ||
"control", # handled internally | ||
"link", # handled internally via control | ||
"qu" # handled internally via self$quantiles | ||
) | ||
|
||
# note that you can also pass a list of functions in case $.train calls more than one | ||
# function, e.g. for control arguments | ||
paramtest = run_paramtest(learner, fun, exclude, tag = "train") | ||
expect_paramtest(paramtest) | ||
}) | ||
|
||
test_that("regr.qgam predict", { | ||
learner = lrn("regr.qgam") | ||
fun = mgcv::predict.gam | ||
exclude = c( | ||
"object", # handled internally | ||
"newdata", # handled internally | ||
"type", # handled internally | ||
"newdata.guaranteed", # handled internally | ||
"na.action", # handled internally | ||
"se.fit", # handled internally | ||
"terms", # not relevant for predict type "se" or "response" | ||
"exclude", # not relevant for predict type "se" or "response" | ||
"iterms.type", # not relevant for predict type "se" or "response" | ||
"qu" | ||
) | ||
|
||
paramtest = run_paramtest(learner, fun, exclude, tag = "predict") | ||
expect_paramtest(paramtest) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
test_that("autotest", { | ||
learner = lrn("regr.qgam") | ||
expect_learner(learner) | ||
result = run_autotest(learner, exclude = "utf8_feature_names") | ||
expect_true(result, info = result$error) | ||
}) |