Skip to content

Commit

Permalink
Quantile GAM learner (#404)
Browse files Browse the repository at this point in the history
* Add qgam learner

* Tests for qgam learner

* small corrections

* fix style issues

* more style fixes

* Apply changes from code review

* check that features in form and column_roles are the same
  • Loading branch information
lona-k authored Jan 28, 2025
1 parent dd5e37c commit 39ba173
Show file tree
Hide file tree
Showing 8 changed files with 332 additions and 1 deletion.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ Authors@R: c(
person("John", "Zobolas", ,"[email protected]", role = "ctb",
comment = c(ORCID = "0000-0002-3609-8674")),
person("Lukas", "Burk", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0001-7528-3795"))
comment = c(ORCID = "0000-0001-7528-3795")),
person("Lona", "Koers", , "[email protected]", role = "ctb")
)
Description: Extra learners for use in mlr3.
License: LGPL-3
Expand Down Expand Up @@ -85,6 +86,7 @@ Suggests:
pracma,
prioritylasso (>= 0.3.1),
pseudo,
qgam,
randomForest,
randomPlantedForest,
randomForestSRC (>= 3.3.0),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ export(LearnerRegrMars)
export(LearnerRegrMob)
export(LearnerRegrMultilayerPerceptron)
export(LearnerRegrPriorityLasso)
export(LearnerRegrQGam)
export(LearnerRegrREPTree)
export(LearnerRegrRSM)
export(LearnerRegrRVM)
Expand Down
9 changes: 9 additions & 0 deletions R/bibentries.R
Original file line number Diff line number Diff line change
Expand Up @@ -711,5 +711,14 @@ bibentries = c( # nolint start
publisher = "ACM Press",
title = "Large margin classification using the perceptron algorithm",
year = "1998"
),
faisolo2017qgam = bibentry("article",
title = "Fast Calibrated Additive Quantile Regression",
author = "Fasiolo, Matteo and Wood, Simon N. and Zaffran, Margaux and Nedellec, Raphael and Goude, Yannig",
year = "2017",
journal = "Journal of the American Statistical Association",
volume = "116",
pages = "1402--1412",
doi = "10.1080/01621459.2020.1725521"
)
) # nolint end
141 changes: 141 additions & 0 deletions R/learner_qgam_regr_qgam.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#' @title Regression Quantile Generalized Additive Model Learner
#' @author lona-k
#' @name mlr_learners_regr.qgam
#'
#' @description
#' Quantile Regression with generalized additive models.
#' Calls [qgam::qgam()] from package \CRANpkg{qgam}.
#'
#' @section Form:
#' For the `form` parameter, a gam formula specific to the [Task][mlr3::Task] is required (see example and `?mgcv::formula.gam`).
#' If no formula is provided, a fallback formula using all features in the task is used that will make the Learner behave like Linear Quantile Regression.
#' The features specified in the formula need to be the same as columns with col_roles "feature" in the task.
#'
#' @section Quantile:
#' The quantile for the Learner, i.e. `qu` parameter from [qgam::qgam()], is set using the value specified in `learner$quantiles`.
#'
#' @templateVar id regr.qgam
#' @template learner
#'
#' @references
#' `r format_bib("faisolo2017qgam")`
#'
#' @template seealso_learner
#' @examplesIf requireNamespace("qgam", quietly = TRUE)
#' # simple example
# t = mlr3::tsk("mtcars")
# l = mlr3::lrn("regr.qgam")
# t$select(c("cyl", "am", "disp", "hp"))
# l$param_set$values$form = mpg ~ cyl + am + s(disp) + s(hp)
# l$quantiles = 0.25
# l$train(t)
# l$model
# l$predict(t)
#' @export
LearnerRegrQGam = R6Class("LearnerRegrQGam",
inherit = LearnerRegr,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
form = p_uty(tags = "train"),
lsig = p_dbl(tags = "train"),
err = p_dbl(lower = 0, upper = 1, tags = "train"),
cluster = p_uty(default = NULL, tags = "train"),
multicore = p_lgl(tags = "train"),
ncores = p_dbl(tags = "train"),
paropts = p_uty(default = list(), tags = "train"),
link = p_uty(default = "identity", tags = "train"),
argGam = p_uty(custom_check = crate(function(x) {
checkmate::check_list(x, names = "unique", null.ok = TRUE)
}), tags = "train"),
block.size = p_int(default = 1000L, tags = "predict"),
unconditional = p_lgl(default = FALSE, tags = "predict")
)

super$initialize(
id = "regr.qgam",
packages = "qgam",
feature_types = c("logical", "integer", "numeric", "factor"),
predict_types = c("response", "se", "quantiles"),
param_set = param_set,
properties = "weights",
man = "mlr3extralearners::mlr_learners_regr.qgam",
label = "Regression Quantile Generalized Additive Model Learner"
)

self$predict_type = "quantiles"
self$quantiles = 0.5
}
),
private = list(
.train = function(task) {
data = task$data(cols = c(task$feature_names, task$target_names))

# get parameters for training
pars = self$param_set$get_values(tags = "train")
control_pars = if (length(pars$link)) list(pars$link) else list(NULL)

args_gam = formalArgs(mgcv::gam)[formalArgs(mgcv::gam) %nin% c("formula", "family", "data")]
if (length(pars$argGam)) {
checkmate::assert_subset(names(pars$argGam), choices = args_gam, empty.ok = FALSE)
}

arg_gam_pars = pars$argGam
pars = pars[names(pars) %nin% c("argGam", "link")]

if ("weights" %in% task$properties) {
arg_gam_pars = insert_named(arg_gam_pars, list(weights = task$weights$weight))
}

if (is.null(pars$form)) {
form = stats::reformulate(task$feature_names, response = task$target_names)
pars$form = form
}

checkmate::assert_set_equal(all.vars(pars$form)[-1], task$col_roles$feature)
checkmate::assert_set_equal(all.vars(pars$form)[[1]], task$col_roles$target)
invoke(
qgam::qgam,
qu = self$quantiles,
data = data,
.args = pars,
control = control_pars,
argGam = arg_gam_pars
)
},
.predict = function(task) { # qgam uses predict.gam
# get parameters with tag "predict"
pars = self$param_set$get_values(tags = "predict")

# get newdata and ensure same ordering in train and predict
newdata = ordered_features(task, self)

include_se = (self$predict_type == "se")

preds = invoke(
predict,
self$model,
newdata = newdata,
type = "response",
newdata.guaranteed = TRUE,
se.fit = include_se,
.args = pars
)

if (include_se) { # se and response
list(response = preds$fit, se = preds$se)
} else if (self$predict_type == "quantiles") {
quantiles = matrix(preds, ncol = 1)
attr(quantiles, "probs") = self$quantiles
attr(quantiles, "response") = self$quantiles
list(quantiles = quantiles)
} else {
list(response = preds)
}
}
)
)

.extralrns_dict$add("regr.qgam", LearnerRegrQGam)
1 change: 1 addition & 0 deletions man/mlr3extralearners-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

136 changes: 136 additions & 0 deletions man/mlr_learners_regr.qgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions tests/testthat/test_paramtest_qgam_regr_qgam.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
test_that("regr.qgam train", {
learner = lrn("regr.qgam")
fun = qgam::qgam
exclude = c(
"data", # handled internally
"control", # handled internally
"link", # handled internally via control
"qu" # handled internally via self$quantiles
)

# note that you can also pass a list of functions in case $.train calls more than one
# function, e.g. for control arguments
paramtest = run_paramtest(learner, fun, exclude, tag = "train")
expect_paramtest(paramtest)
})

test_that("regr.qgam predict", {
learner = lrn("regr.qgam")
fun = mgcv::predict.gam
exclude = c(
"object", # handled internally
"newdata", # handled internally
"type", # handled internally
"newdata.guaranteed", # handled internally
"na.action", # handled internally
"se.fit", # handled internally
"terms", # not relevant for predict type "se" or "response"
"exclude", # not relevant for predict type "se" or "response"
"iterms.type", # not relevant for predict type "se" or "response"
"qu"
)

paramtest = run_paramtest(learner, fun, exclude, tag = "predict")
expect_paramtest(paramtest)
})
6 changes: 6 additions & 0 deletions tests/testthat/test_qgam_regr_qgam.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
test_that("autotest", {
learner = lrn("regr.qgam")
expect_learner(learner)
result = run_autotest(learner, exclude = "utf8_feature_names")
expect_true(result, info = result$error)
})

0 comments on commit 39ba173

Please sign in to comment.