From 6eb5a6e7104587932cea51c61d92ef43824231bb Mon Sep 17 00:00:00 2001 From: Philip Studener Date: Tue, 24 Sep 2024 16:48:05 +0200 Subject: [PATCH 1/9] draft task conversion pipeop --- R/PipeOpTaskSurvRegrPEM.R | 213 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 R/PipeOpTaskSurvRegrPEM.R diff --git a/R/PipeOpTaskSurvRegrPEM.R b/R/PipeOpTaskSurvRegrPEM.R new file mode 100644 index 000000000..88a885b1d --- /dev/null +++ b/R/PipeOpTaskSurvRegrPEM.R @@ -0,0 +1,213 @@ +#' @title PipeOpTaskSurvRegrPEM +#' @name mlr_pipeops_trafotask_survregr_PEM +#' @template param_pipelines +#' +#' @description +#' Transform [TaskSurv] to [TaskRegr][mlr3::TaskRegr] by dividing continuous +#' time into multiple time intervals for each observation. +#' This transformation creates a new target variable `disc_status` that indicates +#' whether an event occurred within each time interval. +#' +#' @section Dictionary: +#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the +#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops] +#' or with the associated sugar function [mlr3pipelines::po()]: +#' ``` +#' PipeOpTaskSurvRegrPEM$new() +#' mlr_pipeops$get("trafotask_survregr_PEM") +#' po("trafotask_survregr_PEM") +#' ``` +#' +#' @section Input and Output Channels: +#' [PipeOpTaskSurvClassifDiscTime] has one input channel named "input", and two +#' output channels, one named "output" and the other "transformed_data". +#' +#' During training, the "output" is the "input" [TaskSurv] transformed to a +#' [TaskRegr][mlr3::TaskRegr]. +#' The target column is named `"disc_status"` and indicates whether an event occurred +#' in each time interval. +#' An additional feature named `"tend"` contains the end time point of each interval. +#' Lastly, the "output" task has an offset column `"offset"`. +#' The "transformed_data" is an empty [data.table][data.table::data.table]. +#' +#' During prediction, the "input" [TaskSurv] is transformed to the "output" +#' [TaskRegr][mlr3::TaskRegr] with `"disc_status"` as target and the `"tend"` +#' as well as `"offset"` feature included. +#' The "transformed_data" is a [data.table] with columns the `"disc_status"` +#' target of the "output" task, the `"id"` (original observation ids), +#' `"obs_times"` (observed times per `"id"`) and `"tend"` (end time of each interval). +#' This "transformed_data" is only meant to be used with the [PipeOpPredRegrSurvPEM]. +#' +#' @section State: +#' The `$state` contains information about the `cut` parameter used. +#' +#' @section Parameters: +#' The parameters are +#' +#' * `cut :: numeric()`\cr +#' Split points, used to partition the data into intervals based on the `time` column. +#' If unspecified, all unique event times will be used. +#' If `cut` is a single integer, it will be interpreted as the number of equidistant +#' intervals from 0 until the maximum event time. +#' * `max_time :: numeric(1)`\cr +#' If `cut` is unspecified, this will be the last possible event time. +#' All event times after `max_time` will be administratively censored at `max_time.` +#' Needs to be greater than the minimum event time in the given task. +#' +#' @examples +#' +#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3learners"), quietly = TRUE) +#' \dontrun{ +#' library(mlr3) +#' library(mlr3learners) +#' library(mlr3pipelines) +#' +#' task = tsk("lung") +#' +#' # transform the survival task to a poisson regression task +#' # all unique event times are used as cutpoints +#' po_disc = po("trafotask_survregr_PEM") +#' task_regr = po_disc$train(list(task))[[1L]] +#' +#' # the end time points of the discrete time intervals +#' unique(task_regr$data(cols = "tend"))[[1L]] +#' +#' # train a classification learner +#' learner = lrn("classif.log_reg", predict_type = "prob") +#' learner$train(task_regr) +#' } +#' } +#' +#' +#' @family PipeOps +#' @family Transformation PipeOps +#' @export +PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", + inherit = mlr3pipelines::PipeOp, + + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + initialize = function(id = "trafotask_survregr_PEM") { + param_set = ps( + cut = p_uty(default = NULL), + max_time = p_dbl(0, default = NULL, special_vals = list(NULL)) + ) + super$initialize( + id = id, + param_set = param_set, + input = data.table( + name = "input", + train = "TaskSurv", + predict = "TaskSurv" + ), + output = data.table( + name = c("output", "transformed_data"), + train = c("TaskRegr", "data.table"), + predict = c("TaskRegr", "data.table") + ) + ) + } + ), + + private = list( + .train = function(input) { + task = input[[1L]] + assert_true(task$censtype == "right") + data = task$data() + + if ("disc_status" %in% colnames(task$data())) { + stop("\"disc_status\" can not be a column in the input data.") + } + + cut = assert_numeric(self$param_set$values$cut, null.ok = TRUE, lower = 0) + max_time = self$param_set$values$max_time + + time_var = task$target_names[1] + event_var = task$target_names[2] + if (testInt(cut, lower = 1)) { + cut = seq(0, data[get(event_var) == 1, max(get(time_var))], length.out = cut + 1) + } + + if (!is.null(max_time)) { + assert(max_time > data[get(event_var) == 1, min(get(time_var))], + "max_time must be greater than the minimum event time.") + } + + form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") + + long_data = pammtools::as_ped(data = data, formula = form, cut = cut, max_time = max_time) + self$state$cut = attributes(long_data)$trafo_args$cut + long_data = as.data.table(long_data) + setnames(long_data, old = "ped_status", new = "disc_status") + + # remove some columns from `long_data` + long_data[, c("tstart", "interval") := NULL] + # keep id mapping + reps = table(long_data$id) + ids = rep(task$row_ids, times = reps) + id = NULL + long_data[, id := ids] + + task_disc = TaskRegr$new(paste0(task$id, "_disc"), long_data, + target = "disc_status") + task_disc$set_col_roles("id", roles = "name") + + list(task_disc, data.table()) + }, + + .predict = function(input) { + task = input[[1]] + data = task$data() + + # extract `cut` from `state` + cut = self$state$cut + + time_var = task$target_names[1] + event_var = task$target_names[2] + + max_time = max(cut) + time = data[[time_var]] + data[[time_var]] = max_time + + status = data[[event_var]] + data[[event_var]] = 1 + + # update form + form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") + + long_data = as.data.table(pammtools::as_ped(data, formula = form, cut = cut)) + setnames(long_data, old = "ped_status", new = "disc_status") + + disc_status = id = tend = obs_times = NULL # fixing global binding notes of data.table + long_data[, disc_status := 0] + # set correct id + rows_per_id = nrow(long_data) / length(unique(long_data$id)) + long_data$obs_times = rep(time, each = rows_per_id) + ids = rep(task$row_ids, each = rows_per_id) + long_data[, id := ids] + + # set correct disc_status + reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count + status = rep(status, times = reps) + long_data[long_data[, .I[tend >= obs_times], by = id]$V1, disc_status := status] + + # remove some columns from `long_data` + long_data[, c("tstart", "interval", "obs_times") := NULL] + task_disc = TaskRegr$new(paste0(task$id, "_disc"), long_data, + target = "disc_status") + task_disc$set_col_roles("id", roles = "name") + + # map observed times back + reps = table(long_data$id) + long_data$obs_times = rep(time, each = rows_per_id) + # subset transformed data + columns_to_keep = c("id", "obs_times", "tend", "disc_status", "offset") + long_data = long_data[, columns_to_keep, with = FALSE] + + list(task_disc, long_data) + } + ) +) + +register_pipeop("trafotask_survregr_PEM", PipeOpTaskSurvRegrPEM) From 4d49fa33ad1bf258934015992ea7b400fabbe36f Mon Sep 17 00:00:00 2001 From: studener Date: Thu, 26 Sep 2024 14:51:05 +0200 Subject: [PATCH 2/9] draft pipeop + pipeline --- R/PipeOpPredRegrSurvPEM.R | 95 +++++++++++++++++++++++++++++++++++++++ R/pipelines.R | 76 +++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+) create mode 100644 R/PipeOpPredRegrSurvPEM.R diff --git a/R/PipeOpPredRegrSurvPEM.R b/R/PipeOpPredRegrSurvPEM.R new file mode 100644 index 000000000..5a4a11ed2 --- /dev/null +++ b/R/PipeOpPredRegrSurvPEM.R @@ -0,0 +1,95 @@ +#' @title PipeOpPredRegrSurvPEM +#' @name mlr_pipeops_trafopred_regrsurv_PEM +#' +#' @description +#' Transform [PredictionRegr] to [PredictionSurv]. +#' +#' @section Dictionary: +#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the +#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops] +#' or with the associated sugar function [mlr3pipelines::po()]: +#' ``` +#' PipeOpPredRegrSurvPEM$new() +#' mlr_pipeops$get("trafopred_regrsurv_PEM") +#' po("trafopred_regrsurv_PEM") +#' ``` +#' +#' @section Input and Output Channels: +#' The input is a [PredictionRegr] and a [data.table][data.table::data.table] +#' with the transformed data both generated by [PipeOpTaskSurvRegrPEM]. +#' The output is the input [PredictionRegr] transformed to a [PredictionSurv]. +#' Only works during prediction phase. +#' +#' @family PipeOps +#' @family Transformation PipeOps +#' @export +PipeOpPredRegrSurvPEM = R6Class( + "PipeOpPredRegrSurvPEM", + inherit = mlr3pipelines::PipeOp, + + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + #' @param id (character(1))\cr + #' Identifier of the resulting object. + initialize = function(id = "trafopred_regrsurv_PEM") { + super$initialize( + id = id, + input = data.table( + name = c("input", "transformed_data"), + train = c("NULL", "data.table"), + predict = c("PredictionRegr", "data.table") + ), + output = data.table( + name = "output", + train = "NULL", + predict = "PredictionSurv" + ) + ) + } + ), + + private = list( + .predict = function(input) { + pred = input[[1]] + data = input[[2]] + assert_true(!is.null(pred$response)) + # probability of having the event (1) in each respective interval + # is the discrete-time hazard + data = cbind(data, dt_hazard = pred$response) + + # From theory, convert hazards to surv as prod(1 - h(t)) + rows_per_id = nrow(data) / length(unique(data$id)) + surv = t(vapply(unique(data$id), function(unique_id) { + 1 - cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]])) + }, numeric(rows_per_id))) + + unique_end_times = sort(unique(data$tend)) + # coerce to distribution and crank + pred_list = .surv_return(times = unique_end_times, surv = surv) + + # select the real tend values by only selecting the last row of each id + # basically a slightly more complex unique() + real_tend = data$obs_times[seq_len(nrow(data)) %% rows_per_id == 0] + + ids = unique(data$id) + # select last row for every id => observed times + id = disc_status = NULL # to fix note + data = data[, .SD[.N, list(disc_status)], by = id] + + # create prediction object + p = PredictionSurv$new( + row_ids = ids, + crank = pred_list$crank, distr = pred_list$distr, + truth = Surv(real_tend, as.integer(as.character(data$disc_status)))) + + list(p) + }, + + .train = function(input) { + self$state = list() + list(input) + } + ) +) +register_pipeop("trafopred_regrsurv_PEM", PipeOpPredRegrSurvPEM) diff --git a/R/pipelines.R b/R/pipelines.R index 8587f6d44..f575b1013 100644 --- a/R/pipelines.R +++ b/R/pipelines.R @@ -659,6 +659,81 @@ pipeline_survtoclassif_disctime = function(learner, cut = NULL, max_time = NULL, gr } +#' @name mlr_graphs_survtoregr_PEM +#' @title Survival to Poisson Regression Reduction Pipeline +#' @description Wrapper around multiple [PipeOp][mlr3pipelines::PipeOp]s to help in creation +#' of complex survival reduction methods. +#' +#' @param learner [LearnerRegr][mlr3::LearnerRegr]\cr +#' Regression learner to fit the transformed [TaskRegr][mlr3::TaskRegr]. +#' `learner` must be able to handle `offset`. +#' @param cut `numeric()`\cr +#' Split points, used to partition the data into intervals. +#' If unspecified, all unique event times will be used. +#' If `cut` is a single integer, it will be interpreted as the number of equidistant +#' intervals from 0 until the maximum event time. +#' @param max_time `numeric(1)`\cr +#' If cut is unspecified, this will be the last possible event time. +#' All event times after max_time will be administratively censored at max_time. +#' @param graph_learner `logical(1)`\cr +#' If `TRUE` returns wraps the [Graph][mlr3pipelines::Graph] as a +#' [GraphLearner][mlr3pipelines::GraphLearner] otherwise (default) returns as a `Graph`. +#' +#' @details +#' The pipeline consists of the following steps: +#' \enumerate{ +#' \item [PipeOpTaskSurvRegrPEM] Converts [TaskSurv] to a [TaskRegr][mlr3::TaskRegr]. +#' \item A [LearnerRegr] is fit and predicted on the new `TaskRegr`. +#' \item [PipeOpPredRegrSurvPEM] transforms the resulting [PredictionRegr][mlr3::PredictionRegr] +#' to [PredictionSurv]. +#' } +#' +#' @return [mlr3pipelines::Graph] or [mlr3pipelines::GraphLearner] +#' @family pipelines +#' +#' @examples +#' \dontrun{ +#' if (requireNamespace("mlr3pipelines", quietly = TRUE) && +#' requireNamespace("mlr3learners", quietly = TRUE)) { +#' +#' library(mlr3) +#' library(mlr3learners) +#' library(mlr3pipelines) +#' +#' task = tsk("lung") +#' part = partition(task) +#' +#' grlrn = ppl( +#' "survtoregr_PEM", +#' learner = lrn("regr.xgboost") +#' ) +#' grlrn$train(task, row_ids = part$train) +#' grlrn$predict(task, row_ids = part$test) +#' } +#' } +#' @export +pipeline_survtoregr_PEM = function(learner, cut = NULL, max_time = NULL, + rhs = NULL, graph_learner = FALSE) { + # TODO: add assertions + + gr = mlr3pipelines::Graph$new() + gr$add_pipeop(mlr3pipelines::po("trafotask_survregr_PEM", cut = cut, max_time = max_time)) + gr$add_pipeop(mlr3pipelines::po("learner", learner)) + gr$add_pipeop(mlr3pipelines::po("nop")) + gr$add_pipeop(mlr3pipelines::po("trafopred_regrsurv_PEM")) + + gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = learner$id, src_channel = "output", dst_channel = "input") + gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = "nop", src_channel = "transformed_data", dst_channel = "input") + gr$add_edge(src_id = learner$id, dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "input") + gr$add_edge(src_id = "nop", dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "transformed_data") + + if (graph_learner) { + gr = mlr3pipelines::GraphLearner$new(gr) + } + + gr +} + register_graph("survaverager", pipeline_survaverager) register_graph("survbagging", pipeline_survbagging) register_graph("crankcompositor", pipeline_crankcompositor) @@ -667,3 +742,4 @@ register_graph("responsecompositor", pipeline_responsecompositor) register_graph("probregr", pipeline_probregr) register_graph("survtoregr", pipeline_survtoregr) register_graph("survtoclassif_disctime", pipeline_survtoclassif_disctime) +register_graph("survregr_PEM", pipeline_survtoregr_PEM) From d786d0807fd74d292ff733af9cb7171c598e6865 Mon Sep 17 00:00:00 2001 From: studener Date: Thu, 26 Sep 2024 16:58:50 +0200 Subject: [PATCH 3/9] update pred conversion pipeop --- R/PipeOpPredRegrSurvPEM.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/PipeOpPredRegrSurvPEM.R b/R/PipeOpPredRegrSurvPEM.R index 5a4a11ed2..66a564f8f 100644 --- a/R/PipeOpPredRegrSurvPEM.R +++ b/R/PipeOpPredRegrSurvPEM.R @@ -61,7 +61,7 @@ PipeOpPredRegrSurvPEM = R6Class( # From theory, convert hazards to surv as prod(1 - h(t)) rows_per_id = nrow(data) / length(unique(data$id)) surv = t(vapply(unique(data$id), function(unique_id) { - 1 - cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]])) + exp(-cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]]))) }, numeric(rows_per_id))) unique_end_times = sort(unique(data$tend)) From b7bc0c6785a9f10f2a4f36579595f77c16d63434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20G=C3=B6swein?= Date: Tue, 29 Oct 2024 11:01:15 +0100 Subject: [PATCH 4/9] added modelmatrix pipeop to PEM pipeline, changed variable naming to PEM, fixed minor bugs in PipeOp...PEM --- R/PipeOpPredRegrSurvPEM.R | 8 +++---- R/PipeOpTaskSurvRegrPEM.R | 49 ++++++++++++++++++++------------------- R/pipelines.R | 10 +++++++- 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/R/PipeOpPredRegrSurvPEM.R b/R/PipeOpPredRegrSurvPEM.R index 66a564f8f..28d176763 100644 --- a/R/PipeOpPredRegrSurvPEM.R +++ b/R/PipeOpPredRegrSurvPEM.R @@ -58,7 +58,7 @@ PipeOpPredRegrSurvPEM = R6Class( # is the discrete-time hazard data = cbind(data, dt_hazard = pred$response) - # From theory, convert hazards to surv as prod(1 - h(t)) + # From theory, convert hazards to surv as exp(-cumsum(h(t) * exp(offset))) rows_per_id = nrow(data) / length(unique(data$id)) surv = t(vapply(unique(data$id), function(unique_id) { exp(-cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]]))) @@ -74,14 +74,14 @@ PipeOpPredRegrSurvPEM = R6Class( ids = unique(data$id) # select last row for every id => observed times - id = disc_status = NULL # to fix note - data = data[, .SD[.N, list(disc_status)], by = id] + id = PEM_status = NULL # to fix note + data = data[, .SD[.N, list(PEM_status)], by = id] # create prediction object p = PredictionSurv$new( row_ids = ids, crank = pred_list$crank, distr = pred_list$distr, - truth = Surv(real_tend, as.integer(as.character(data$disc_status)))) + truth = Surv(real_tend, as.integer(as.character(data$PEM_status)))) list(p) }, diff --git a/R/PipeOpTaskSurvRegrPEM.R b/R/PipeOpTaskSurvRegrPEM.R index 88a885b1d..2ef9fdf8a 100644 --- a/R/PipeOpTaskSurvRegrPEM.R +++ b/R/PipeOpTaskSurvRegrPEM.R @@ -5,7 +5,7 @@ #' @description #' Transform [TaskSurv] to [TaskRegr][mlr3::TaskRegr] by dividing continuous #' time into multiple time intervals for each observation. -#' This transformation creates a new target variable `disc_status` that indicates +#' This transformation creates a new target variable `PEM_status` that indicates #' whether an event occurred within each time interval. #' #' @section Dictionary: @@ -19,21 +19,21 @@ #' ``` #' #' @section Input and Output Channels: -#' [PipeOpTaskSurvClassifDiscTime] has one input channel named "input", and two +#' [PipeOpTaskSurvRegrPEM] has one input channel named "input", and two #' output channels, one named "output" and the other "transformed_data". #' #' During training, the "output" is the "input" [TaskSurv] transformed to a #' [TaskRegr][mlr3::TaskRegr]. -#' The target column is named `"disc_status"` and indicates whether an event occurred +#' The target column is named `"PEM_status"` and indicates whether an event occurred #' in each time interval. #' An additional feature named `"tend"` contains the end time point of each interval. #' Lastly, the "output" task has an offset column `"offset"`. #' The "transformed_data" is an empty [data.table][data.table::data.table]. #' #' During prediction, the "input" [TaskSurv] is transformed to the "output" -#' [TaskRegr][mlr3::TaskRegr] with `"disc_status"` as target and the `"tend"` +#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target and the `"tend"` #' as well as `"offset"` feature included. -#' The "transformed_data" is a [data.table] with columns the `"disc_status"` +#' The "transformed_data" is a [data.table] with columns the `"PEM_status"` #' target of the "output" task, the `"id"` (original observation ids), #' `"obs_times"` (observed times per `"id"`) and `"tend"` (end time of each interval). #' This "transformed_data" is only meant to be used with the [PipeOpPredRegrSurvPEM]. @@ -58,6 +58,7 @@ #' #' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3learners"), quietly = TRUE) #' \dontrun{ +#' # Update documentation to match PEM #' library(mlr3) #' library(mlr3learners) #' library(mlr3pipelines) @@ -66,8 +67,8 @@ #' #' # transform the survival task to a poisson regression task #' # all unique event times are used as cutpoints -#' po_disc = po("trafotask_survregr_PEM") -#' task_regr = po_disc$train(list(task))[[1L]] +#' po_PEM = po("trafotask_survregr_PEM") +#' task_regr = po_PEM$train(list(task))[[1L]] #' #' # the end time points of the discrete time intervals #' unique(task_regr$data(cols = "tend"))[[1L]] @@ -116,8 +117,8 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", assert_true(task$censtype == "right") data = task$data() - if ("disc_status" %in% colnames(task$data())) { - stop("\"disc_status\" can not be a column in the input data.") + if ("PEM_status" %in% colnames(task$data())) { + stop("\"PEM_status\" can not be a column in the input data.") } cut = assert_numeric(self$param_set$values$cut, null.ok = TRUE, lower = 0) @@ -139,7 +140,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", long_data = pammtools::as_ped(data = data, formula = form, cut = cut, max_time = max_time) self$state$cut = attributes(long_data)$trafo_args$cut long_data = as.data.table(long_data) - setnames(long_data, old = "ped_status", new = "disc_status") + setnames(long_data, old = "ped_status", new = "PEM_status") #change to PEM # remove some columns from `long_data` long_data[, c("tstart", "interval") := NULL] @@ -149,11 +150,11 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", id = NULL long_data[, id := ids] - task_disc = TaskRegr$new(paste0(task$id, "_disc"), long_data, - target = "disc_status") - task_disc$set_col_roles("id", roles = "name") + task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, + target = "PEM_status") + task_PEM$set_col_roles("id", roles = "group") - list(task_disc, data.table()) + list(task_PEM, data.table()) }, .predict = function(input) { @@ -177,35 +178,35 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") long_data = as.data.table(pammtools::as_ped(data, formula = form, cut = cut)) - setnames(long_data, old = "ped_status", new = "disc_status") + setnames(long_data, old = "ped_status", new = "PEM_status") - disc_status = id = tend = obs_times = NULL # fixing global binding notes of data.table - long_data[, disc_status := 0] + PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table + long_data[, PEM_status := 0] # set correct id rows_per_id = nrow(long_data) / length(unique(long_data$id)) long_data$obs_times = rep(time, each = rows_per_id) ids = rep(task$row_ids, each = rows_per_id) long_data[, id := ids] - # set correct disc_status + # set correct PEM_status reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count status = rep(status, times = reps) - long_data[long_data[, .I[tend >= obs_times], by = id]$V1, disc_status := status] + long_data[long_data[, .I[tend >= obs_times], by = id]$V1, PEM_status := status] # remove some columns from `long_data` long_data[, c("tstart", "interval", "obs_times") := NULL] - task_disc = TaskRegr$new(paste0(task$id, "_disc"), long_data, - target = "disc_status") - task_disc$set_col_roles("id", roles = "name") + task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, + target = "PEM_status") + task_PEM$set_col_roles("id", roles = "group") # map observed times back reps = table(long_data$id) long_data$obs_times = rep(time, each = rows_per_id) # subset transformed data - columns_to_keep = c("id", "obs_times", "tend", "disc_status", "offset") + columns_to_keep = c("id", "obs_times", "tend", "PEM_status", "offset") long_data = long_data[, columns_to_keep, with = FALSE] - list(task_disc, long_data) + list(task_PEM, long_data) } ) ) diff --git a/R/pipelines.R b/R/pipelines.R index f575b1013..7bce79af4 100644 --- a/R/pipelines.R +++ b/R/pipelines.R @@ -727,6 +727,14 @@ pipeline_survtoregr_PEM = function(learner, cut = NULL, max_time = NULL, gr$add_edge(src_id = learner$id, dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "input") gr$add_edge(src_id = "nop", dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "transformed_data") + + if (!is.null(rhs)) { + gr$edges = gr$edges[-1, ] + gr$add_pipeop(mlr3pipelines::po("modelmatrix", formula = formulate(rhs = rhs, quote = "left"))) + gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = "modelmatrix", src_channel = "output") + gr$add_edge(src_id = "modelmatrix", dst_id = learner$id, src_channel = "output", dst_channel = "input") + } + if (graph_learner) { gr = mlr3pipelines::GraphLearner$new(gr) } @@ -742,4 +750,4 @@ register_graph("responsecompositor", pipeline_responsecompositor) register_graph("probregr", pipeline_probregr) register_graph("survtoregr", pipeline_survtoregr) register_graph("survtoclassif_disctime", pipeline_survtoclassif_disctime) -register_graph("survregr_PEM", pipeline_survtoregr_PEM) +register_graph("survtoregr_PEM", pipeline_survtoregr_PEM) From 5d6b61b03a725314bb5dd1f0284bcbd0d01c9519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20G=C3=B6swein?= Date: Thu, 14 Nov 2024 14:57:07 +0100 Subject: [PATCH 5/9] added col_role original_ids to regression tasks --- R/aaa.R | 3 ++- R/zzz.R | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/R/aaa.R b/R/aaa.R index 20925bec5..c874badaf 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -51,7 +51,8 @@ register_reflections = function() { x$task_col_roles$surv = x$task_col_roles$regr x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum") - x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids")) # for discrete time + x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids"))# for discrete time + x$task_col_roles$regr = unique(c(x$task_col_roles$regr, "original_ids")) x$task_properties$surv = x$task_properties$regr x$task_properties$dens = x$task_properties$regr diff --git a/R/zzz.R b/R/zzz.R index bad8118b7..ec6cba062 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -85,6 +85,7 @@ unregister_reflections = function() { x$task_col_roles$surv = NULL x$task_col_roles$dens = NULL x$task_col_roles$classif = setdiff(x$task_col_roles$classif, "original_ids") + x$task_col_roles$regr = setdiff(x$task_col_roles$regr, 'original_ids') x$task_properties$surv = NULL x$task_properties$dens = NULL From c19e4bbac3c5bef649683d80008ea8a1ac773232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20G=C3=B6swein?= Date: Thu, 14 Nov 2024 14:58:56 +0100 Subject: [PATCH 6/9] changed id column role to original_ids --- R/PipeOpTaskSurvRegrPEM.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PipeOpTaskSurvRegrPEM.R b/R/PipeOpTaskSurvRegrPEM.R index 2ef9fdf8a..3edd0c781 100644 --- a/R/PipeOpTaskSurvRegrPEM.R +++ b/R/PipeOpTaskSurvRegrPEM.R @@ -152,7 +152,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, target = "PEM_status") - task_PEM$set_col_roles("id", roles = "group") + task_PEM$set_col_roles("id", roles = "original_ids") list(task_PEM, data.table()) }, @@ -197,7 +197,7 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", long_data[, c("tstart", "interval", "obs_times") := NULL] task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, target = "PEM_status") - task_PEM$set_col_roles("id", roles = "group") + task_PEM$set_col_roles("id", roles = "original_ids") # map observed times back reps = table(long_data$id) From 717478d86ecbba6ee8fe3f949b969880c0bf98aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20G=C3=B6swein?= Date: Thu, 21 Nov 2024 23:18:03 +0100 Subject: [PATCH 7/9] added additional arguments to TaskSurvRegrPEM to enable more complex risk scenarios in the future, formula is now passed via the form argument during pipeline creation --- R/PipeOpPredRegrSurvPEM.R | 2 ++ R/PipeOpTaskSurvRegrPEM.R | 26 +++++++++++++++++++------- R/pipelines.R | 4 ++-- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/R/PipeOpPredRegrSurvPEM.R b/R/PipeOpPredRegrSurvPEM.R index 28d176763..0667bfb8c 100644 --- a/R/PipeOpPredRegrSurvPEM.R +++ b/R/PipeOpPredRegrSurvPEM.R @@ -60,6 +60,8 @@ PipeOpPredRegrSurvPEM = R6Class( # From theory, convert hazards to surv as exp(-cumsum(h(t) * exp(offset))) rows_per_id = nrow(data) / length(unique(data$id)) + + # If 'single_event', 'cr', 'msm') surv = t(vapply(unique(data$id), function(unique_id) { exp(-cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]]))) }, numeric(rows_per_id))) diff --git a/R/PipeOpTaskSurvRegrPEM.R b/R/PipeOpTaskSurvRegrPEM.R index 3edd0c781..f83075424 100644 --- a/R/PipeOpTaskSurvRegrPEM.R +++ b/R/PipeOpTaskSurvRegrPEM.R @@ -92,7 +92,11 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", initialize = function(id = "trafotask_survregr_PEM") { param_set = ps( cut = p_uty(default = NULL), - max_time = p_dbl(0, default = NULL, special_vals = list(NULL)) + max_time = p_dbl(0, default = NULL, special_vals = list(NULL)), + censor_code = p_int(0L), + min_events = p_int(1L), + form = p_uty(tags = 'train') + #pammtools arguments: transitions etc. ) super$initialize( id = id, @@ -134,11 +138,19 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", assert(max_time > data[get(event_var) == 1, min(get(time_var))], "max_time must be greater than the minimum event time.") } - - form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") - - long_data = pammtools::as_ped(data = data, formula = form, cut = cut, max_time = max_time) + + # To-Do: Extend to a more general formulation for competing risks and msm + # form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") + # To-Do: provide formula not as string, not via formula(...) + long_data = pammtools::as_ped(data = data, formula = formula(self$param_set$values$form), cut = cut, max_time = max_time) self$state$cut = attributes(long_data)$trafo_args$cut + # To-Do: + # extract other attributes (risks) for correct computation of predictions for competing risks and msm + # class(long_data) == ped_cr, ped_msmor ped + # self$state$risks = attributes(long_data)$risks + + + long_data = as.data.table(long_data) setnames(long_data, old = "ped_status", new = "PEM_status") #change to PEM @@ -175,9 +187,9 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", data[[event_var]] = 1 # update form - form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") + # form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") - long_data = as.data.table(pammtools::as_ped(data, formula = form, cut = cut)) + long_data = as.data.table(pammtools::as_ped(data, formula = formula(self$param_set$values$form), cut = cut)) setnames(long_data, old = "ped_status", new = "PEM_status") PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table diff --git a/R/pipelines.R b/R/pipelines.R index 7bce79af4..5fff52c3c 100644 --- a/R/pipelines.R +++ b/R/pipelines.R @@ -713,11 +713,11 @@ pipeline_survtoclassif_disctime = function(learner, cut = NULL, max_time = NULL, #' } #' @export pipeline_survtoregr_PEM = function(learner, cut = NULL, max_time = NULL, - rhs = NULL, graph_learner = FALSE) { + rhs = NULL, graph_learner = FALSE, form = NULL) { # TODO: add assertions gr = mlr3pipelines::Graph$new() - gr$add_pipeop(mlr3pipelines::po("trafotask_survregr_PEM", cut = cut, max_time = max_time)) + gr$add_pipeop(mlr3pipelines::po("trafotask_survregr_PEM", cut = cut, max_time = max_time, form = form)) gr$add_pipeop(mlr3pipelines::po("learner", learner)) gr$add_pipeop(mlr3pipelines::po("nop")) gr$add_pipeop(mlr3pipelines::po("trafopred_regrsurv_PEM")) From 976d9c0f3dd731112d09d95df901b59b6497f49f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20G=C3=B6swein?= Date: Fri, 6 Dec 2024 10:08:07 +0100 Subject: [PATCH 8/9] form is now to be passed without quotation marks --- R/PipeOpTaskSurvRegrPEM.R | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/R/PipeOpTaskSurvRegrPEM.R b/R/PipeOpTaskSurvRegrPEM.R index f83075424..58011c051 100644 --- a/R/PipeOpTaskSurvRegrPEM.R +++ b/R/PipeOpTaskSurvRegrPEM.R @@ -140,14 +140,10 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", } # To-Do: Extend to a more general formulation for competing risks and msm - # form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") - # To-Do: provide formula not as string, not via formula(...) - long_data = pammtools::as_ped(data = data, formula = formula(self$param_set$values$form), cut = cut, max_time = max_time) + # Issue: We pass form (e.g. Surv(time, status) ~ .) which currently serves to correctly transform the data into ped format + # but doesn't serve any other purpose yet. For ML learners, such as xgb, the covariate structure is passed to the pipeline via rhs not form. + long_data = pammtools::as_ped(data = data, formula = self$param_set$values$form, cut = cut, max_time = max_time) self$state$cut = attributes(long_data)$trafo_args$cut - # To-Do: - # extract other attributes (risks) for correct computation of predictions for competing risks and msm - # class(long_data) == ped_cr, ped_msmor ped - # self$state$risks = attributes(long_data)$risks @@ -186,10 +182,8 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM", status = data[[event_var]] data[[event_var]] = 1 - # update form - # form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".") - long_data = as.data.table(pammtools::as_ped(data, formula = formula(self$param_set$values$form), cut = cut)) + long_data = as.data.table(pammtools::as_ped(data, formula = self$param_set$values$form, cut = cut)) setnames(long_data, old = "ped_status", new = "PEM_status") PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table From 35745f0b58174a25ee25b9c4a737a693318985f6 Mon Sep 17 00:00:00 2001 From: markusgoeswein Date: Fri, 31 Jan 2025 15:13:34 +0100 Subject: [PATCH 9/9] resolve merge conflict with main, before merging --- R/pipelines.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/R/pipelines.R b/R/pipelines.R index 5fff52c3c..7f65716a7 100644 --- a/R/pipelines.R +++ b/R/pipelines.R @@ -750,4 +750,7 @@ register_graph("responsecompositor", pipeline_responsecompositor) register_graph("probregr", pipeline_probregr) register_graph("survtoregr", pipeline_survtoregr) register_graph("survtoclassif_disctime", pipeline_survtoclassif_disctime) +register_graph("survtoclassif_IPCW", pipeline_survtoclassif_IPCW) +register_graph("survtoclassif_vock", pipeline_survtoclassif_IPCW) # alias register_graph("survtoregr_PEM", pipeline_survtoregr_PEM) +