Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Competing Risks #433

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ Collate:
'PredictionDens.R'
'PredictionSurv.R'
'RcppExports.R'
'TaskCompRisks.R'
'TaskDens.R'
'TaskDens_zzz.R'
'TaskGeneratorCoxed.R'
Expand All @@ -134,6 +135,7 @@ Collate:
'TaskSurv_zzz.R'
'as_prediction_dens.R'
'as_prediction_surv.R'
'as_task_cmprisk.R'
'as_task_dens.R'
'as_task_surv.R'
'assertions.R'
Expand Down
6 changes: 5 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ S3method(as_prediction_dens,PredictionDens)
S3method(as_prediction_dens,data.frame)
S3method(as_prediction_surv,PredictionSurv)
S3method(as_prediction_surv,data.frame)
S3method(as_task_cmprisk,DataBackend)
S3method(as_task_cmprisk,TaskCompRisks)
S3method(as_task_cmprisk,data.frame)
S3method(as_task_dens,DataBackend)
S3method(as_task_dens,TaskDens)
S3method(as_task_dens,data.frame)
S3method(as_task_surv,DataBackend)
S3method(as_task_surv,TaskSurv)
S3method(as_task_surv,data.frame)
S3method(as_task_surv,formula)
S3method(autoplot,PredictionSurv)
S3method(autoplot,TaskDens)
S3method(autoplot,TaskSurv)
Expand Down Expand Up @@ -79,13 +81,15 @@ export(PipeOpTaskSurvClassifDiscTime)
export(PipeOpTaskSurvClassifIPCW)
export(PredictionDens)
export(PredictionSurv)
export(TaskCompRisks)
export(TaskDens)
export(TaskGeneratorCoxed)
export(TaskGeneratorSimdens)
export(TaskGeneratorSimsurv)
export(TaskSurv)
export(as_prediction_dens)
export(as_prediction_surv)
export(as_task_cmprisk)
export(as_task_dens)
export(as_task_surv)
export(assert_surv)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# mlr3proba 0.8.0

* refactor: `TaskSurv` uses only right, left or interval censoring, simplified code a lot in the methods
* feat: add `TaskCompRisks` class and `as_task_cmprk()` S3 methods (support for right-censored data only)

# mlr3proba 0.7.4

* fix + update `MeasureSurv`: survival measure labels are now printed and the `obs_loss` property is now supported
Expand Down
4 changes: 3 additions & 1 deletion R/LearnerSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ LearnerSurv = R6Class("LearnerSurv",
initialize = function(id, param_set = ps(), predict_types = "distr",
feature_types = character(), properties = character(),
packages = character(), label = NA_character_, man = NA_character_) {

super$initialize(
id = id, task_type = "surv", param_set = param_set, predict_types = predict_types,
feature_types = feature_types, properties = properties,
packages = c("mlr3proba", packages), label = label, man = man)
packages = c("mlr3proba", packages), label = label, man = man
)
}
)
)
2 changes: 1 addition & 1 deletion R/MeasureDens.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' Default density measures: [`dens.logloss`][mlr_measures_dens.logloss]
#' @export
MeasureDens = R6Class("MeasureDens",
inherit = Measure, cloneable = FALSE,
inherit = Measure,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#' Default survival measure: [`surv.cindex`][mlr_measures_surv.cindex]
#' @export
MeasureSurv = R6Class("MeasureSurv",
inherit = Measure, cloneable = FALSE,
inherit = Measure,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
Expand Down
2 changes: 1 addition & 1 deletion R/PipeOpTaskSurvClassifDiscTime.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
private = list(
.train = function(input) {
task = input[[1L]]
assert_true(task$censtype == "right")
assert_true(task$cens_type == "right")
data = task$data()

if ("disc_status" %in% colnames(task$data())) {
Expand Down
2 changes: 1 addition & 1 deletion R/PipeOpTaskSurvClassifIPCW.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ PipeOpTaskSurvClassifIPCW = R6Class("PipeOpTaskSurvClassifIPCW",
task = input[[1]]

# checks
assert_true(task$censtype == "right")
assert_true(task$cens_type == "right")
tau = assert_numeric(self$param_set$values$tau, null.ok = FALSE)
max_event_time = max(task$unique_event_times())
stopifnot(tau <= max_event_time)
Expand Down
8 changes: 2 additions & 6 deletions R/PredictionDataSurv.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#' @export
as_prediction.PredictionDataSurv = function(x, check = TRUE, ...) { # nolint
as_prediction.PredictionDataSurv = function(x, check = TRUE, ...) {
invoke(PredictionSurv$new, check = check, .args = x)
}


#' @export
check_prediction_data.PredictionDataSurv = function(pdata, ...) { # nolint

check_prediction_data.PredictionDataSurv = function(pdata, ...) {
n = length(assert_row_ids(pdata$row_ids))
assert_surv(pdata$truth, "Surv", len = n, any.missing = TRUE, null.ok = TRUE)
assert_numeric(pdata$crank, len = n, any.missing = FALSE, null.ok = FALSE)
Expand All @@ -24,7 +22,6 @@ check_prediction_data.PredictionDataSurv = function(pdata, ...) { # nolint
pdata
}


#' @export
is_missing_prediction_data.PredictionDataSurv = function(pdata, ...) { # nolint
miss = logical(length(pdata$row_ids))
Expand All @@ -44,7 +41,6 @@ is_missing_prediction_data.PredictionDataSurv = function(pdata, ...) { # nolint
pdata$row_ids[miss]
}


#' @export
c.PredictionDataSurv = function(..., keep_duplicates = TRUE) {
dots = list(...)
Expand Down
7 changes: 2 additions & 5 deletions R/PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ PredictionSurv = R6Class("PredictionSurv",
),

private = list(
.censtype = NULL,
.distr = function() self$data$distr %??% NA_real_,
.simplify_distr = function(x) {
if (inherits(x, c("Matdist", "Arrdist"))) {
Expand Down Expand Up @@ -189,16 +188,14 @@ PredictionSurv = R6Class("PredictionSurv",
)
)


#' @export
as.data.table.PredictionSurv = function(x, ...) { # nolint

as.data.table.PredictionSurv = function(x, ...) {
tab = as.data.table(x$data[c("row_ids", "crank", "lp", "response")])
tab$time = x$data$truth[, 1L]
tab$status = as.logical(x$data$truth[, 2L])
if ("distr" %in% x$predict_types) {
# annoyingly need this many lists to get nice printing
tab$distr = list(list(list(r6_private(x)$.distr())))
tab$distr = list(list(list(get_private(x)$.distr())))
}
setcolorder(tab, c("row_ids", "time", "status"))[]
}
206 changes: 206 additions & 0 deletions R/TaskCompRisks.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#' @title Competing Risks Task
#'
#' @description
#' This task extends [mlr3::Task] and [mlr3::TaskSupervised] to handle survival
#' problems with **competing risks**.
#' The target variable consists of survival times and an event indicator, which
#' must be a non-negative integer in the set \eqn{(0,1,2,...,K)}.
#' \eqn{0} represents censored observations, while other integers correspond to
#' distinct competing events.
#' Every row corresponds to one subject/observation.
#'
#' Predefined tasks are stored in [mlr3::mlr_tasks].
#'
#' The `task_type` is set to `"cmprsk"`.
#'
#' **Note:** Currently only right-censoring is supported.
#'
#' @template param_rows
#'
#' @family Task
#' @examples
#' library(mlr3)
#' task = tsk("pbc")
#'
#' # meta data
#' task$target_names # target is always (time, status) for right-censoring tasks
#' task$feature_names
#' task$formula()
#'
#' # survival data
#' task$truth() # survival::Surv() object
#' task$times() # (unsorted) times
#' task$event() # event indicators (0 = censored, >0 = different causes)
#' task$unique_times() # sorted unique times
#' task$unique_event_times() # sorted unique event times (from any cause)
#' task$aalen_johansen(strata = "sex") # Aalen-Johansen estimator
#'
#' # proportion of censored observations across all dataset
#' task$cens_prop()
#'
#' @export
TaskCompRisks = R6Class("TaskCompRisks",
inherit = TaskSupervised,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @details
#' Only right-censoring competing risk tasks are currently supported.
#'
#' @template param_id
#' @template param_backend
#' @param time (`character(1)`)\cr
#' Name of the column for event time.
#' @param event (`character(1)`)\cr
#' Name of the column giving the event indicator (\eqn{0} corresponds to
#' censoring, values \eqn{> 0} correspond to competing events.
#' @param label (`character(1)`)\cr
#' Label for the new instance.
initialize = function(id, backend, time = "time", event = "event",
label = NA_character_) {
# only right-censoring supported
private$.cens_type = "right"
backend = as_data_backend(backend)

# check event is an integer starting from 0
event_col = get_private(backend)$.data[, event, with = FALSE][[1L]]
assert_integerish(event_col, lower = 0L, any.missing = FALSE)

# check that there is at least two competing events
n_cmp_events = sum(unique(event_col) != 0)
if (n_cmp_events < 2) {
stopf("Define at least two competing events, there are only %i in the data",
n_cmp_events)
}

# keep all the event levels
private$.event_levels = levels(as.factor(event_col))

super$initialize(
id = id, task_type = "cmprsk", backend = backend,
target = c(time, event), label = label
)
},

#' @description
#' True response for specified `row_ids`. This is the multi-state format
#' using [Surv][survival::Surv()] with the `event` target column as a `factor`:
#' `Surv(time, as.factor(event))`
#'
#' Defaults to all rows with role `"use"`.
#'
#' @return [survival::Surv()].
truth = function(rows = NULL) {
tn = self$target_names
data = self$data(rows = rows, cols = self$target_names)
times = data[[tn[1L]]]
event = data[[tn[2L]]]

args = list(time = times, event = as.factor(event))
invoke(Surv, .args = args)
},

#' @description
#' Creates a formula for competing risk models with [survival::Surv()] on
#' the LHS (left hand side).
#'
#' @param rhs
#' If `NULL`, RHS (right hand side) is `"."`, otherwise RHS is `"rhs"`.
#'
#' @return [stats::formula()].
formula = function(rhs = NULL) {
tn = self$target_names
lhs = sprintf("Surv(%s, as.factor(%s))", tn[1L], tn[2L])
formulate(lhs, rhs %??% ".", env = getNamespace("survival"))
},

#' @description
#' Returns the (unsorted) outcome times.
#' @return `numeric()`
times = function(rows = NULL) {
truth = self$truth(rows)
as.numeric(truth[, 1L])
},

#' @description
#' Returns the event indicator.
#' @return `integer()`
event = function(rows = NULL) {
# avoid going via `truth()[, 2L]` because `Surv()` + subsetting by `rows`
# results in different encoding than expected
event = self$data(rows = rows, cols = self$target_names[[2L]])[[1L]]
as.integer(event)
},

#' @description
#' Returns the sorted unique outcome times.
#' @return `numeric()`
unique_times = function(rows = NULL) {
sort(unique(self$times(rows)))
},

#' @description
#' Returns the sorted unique event outcome times (by any cause).
#' @return `numeric()`
unique_event_times = function(rows = NULL) {
sort(unique(self$times(rows)[self$event(rows) != 0]))
},

#' @description
#' Calls [survival::survfit()] to calculate the Aalen–Johansen estimator.
#'
#' @param strata (`character()`)\cr
#' Stratification variables to use.
#' @param rows (`integer()`)\cr
#' Subset of row indices.
#' @param ... (any)\cr
#' Additional arguments passed down to [survival::survfit.formula()].
#' @return [survival::survfit.object].
aalen_johansen = function(strata = NULL, rows = NULL, ...) {
assert_character(strata, null.ok = TRUE)
f = self$formula(strata %??% 1)
cols = c(self$target_names, intersect(self$backend$colnames, strata))
data = self$data(rows = rows, cols = cols)
survival::survfit(f, data = data, ...)
},

#' @description
#' Returns the **proportion of censoring** for this competing risks task.
#' By default, this is returned for all observations, otherwise only the
#' specified ones (`rows`).
#'
#' @return `numeric()`
cens_prop = function(rows = NULL) {
event = self$event(rows)
total_censored = sum(event == 0)
n_obs = length(event)

total_censored / n_obs
}
),

active = list(
#' @field cens_type (`character(1)`)\cr
#' Returns the type of censoring.
#'
#' Currently, only the `"right"` censoring type is fully supported.
#' The API might change in the future to support left and interval censoring.
cens_type = function(rhs) {
assert_ro_binding(rhs)
private$.cens_type
},

#' @field cmp_events (`character(1)`)\cr
#' Returns the names of the competing events.
cmp_events = function(rhs) {
assert_ro_binding(rhs)
setdiff(private$.event_levels, "0")
}
),

private = list(
.cens_type = NULL,
.event_levels = NULL
)
)
Loading
Loading