Skip to content

Commit

Permalink
refactor and fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Feb 21, 2024
1 parent db3378c commit 95c7400
Showing 1 changed file with 69 additions and 21 deletions.
90 changes: 69 additions & 21 deletions R/antolini.R
Original file line number Diff line number Diff line change
@@ -1,40 +1,65 @@
cindex = function(pred, meth = c("A", "H"), tiex = 0.5) {
all_times = pred$truth[, 1] # to differentiate with `times` below
n_obs = length(pred$truth)
pred_times = pred$truth[, 1] # to differentiate with `times` below
status = pred$truth[, 2]

# we write some code for Arrdist or VectorDistribution as in other measures
# in the end we just need the survival matrix
surv = pred$data$distr
times = as.numeric(colnames(surv))
risk = unname(pred$data$crank)

if (meth == "A") {
extend_times = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
scores = diag(
extend_times(x = all_times, data = times, cdf = t(1 - surv), FALSE, FALSE)
# extend_times should return square matrix always for what we aim to do
)
surv_mat = extend_times(x = pred_times, data = times, cdf = t(1 - surv), FALSE, FALSE)
# add time points (important for indexing)
rownames(surv_mat) = pred_times

# extend_times should return square matrix always
# edge case to fix: `t(1 - surv)` does not return matrix if surv is 1x() or ()x1
} else {
scores = unname(pred$data$crank)
}

#browser()
pairs = data.frame(
ti = rep(all_times, length(all_times)),
di = rep(status, length(all_times)),
si = rep(scores, length(all_times)),
tj = rep(all_times, each = length(all_times)),
sj = rep(scores, each = length(all_times))
# i obs
i = seq_len(n_obs),
ti = rep(pred_times, n_obs),
di = rep(status, n_obs),
ri = rep(risk, n_obs),
# j obs
j = rep(seq_len(n_obs), each = n_obs),
tj = rep(pred_times, each = n_obs),
rj = rep(risk, each = n_obs)
)

comparable = function(t_i, t_j, d_i, c) d_i & t_i < t_j & t_i < c
comparable = function(ti, tj, di, cutoff) di & ti < tj & ti < cutoff

comp = comparable(pairs$ti, pairs$tj, pairs$di, Inf)
comp = comparable(pairs$ti, pairs$tj, pairs$di, cutoff = Inf)
if (meth == "A") {
conc = pairs[comp, "si"] < pairs[comp, "sj"]
# add the S(Ti,i) and S(Ti,j) for Antolini's C-index
# S is survival matrix with: rows => times, cols => obs
surv_ii = sapply(1:nrow(pairs), function(row_index) {
row = pairs[row_index, ]
unname(surv_mat[as.character(row[, "ti"]), row[, "i"]])
}) # survival of i-th obs at T(i) event time point
# same as:
surv_ii2 = rep(diag(surv_mat), times = n_obs) # much faster, keep that!
testthat::expect_equal(surv_ii, surv_ii2)

surv_ij = sapply(1:nrow(pairs), function(row_index) {
row = pairs[row_index, ]
unname(surv_mat[as.character(row[, "ti"]), row[, "j"]])
}) # survival of j-th obs at T(i) event time point

# fill in the survival probability columns
pairs = cbind(pairs, sii = surv_ii, sij = surv_ij)

conc = pairs[comp, "sii"] < pairs[comp, "sij"]
conc = sum(conc) + sum((pairs[comp, "sii"] == pairs[comp, "sij"])) * tiex
} else {
conc = pairs[comp, "si"] > pairs[comp, "sj"]
conc = pairs[comp, "ri"] > pairs[comp, "rj"]
conc = sum(conc) + sum((pairs[comp, "ri"] == pairs[comp, "rj"])) * tiex
}
conc = sum(conc) + sum((pairs[comp, "si"] == pairs[comp, "sj"])) * tiex

conc / sum(comp)
}
Expand All @@ -45,11 +70,34 @@ set.seed(42)
t = tsk("rats")
s = partition(t)
p = lrn("surv.coxph")$train(t, s$train)$predict(t, s$test)
p$score(msr("surv.rcll"))

# check Harrell
p$score()
cindex(p, "H")

p$score(msr("surv.cindex", tiex = 1))
cindex(p, "H", tiex = 1)

# check Antolini
cindex(p, "A", 0.5)

cindex(p, "H", 0.8) - p$score(msr("surv.cindex", tiex = 0.8))
cindex(p, "H", 0.8) - p$score(msr("surv.cindex", tiex = 0.8)) < 1e-6
microbenchmark::microbenchmark(cindex(p, "H", 0.5), p$score())
cindex(p, "H", 0.8) - p$score(msr("surv.cindex", tiex = 0.8)) # < 1e-6
microbenchmark::microbenchmark(cindex(p, "H", 0.5), p$score()) # faster

# benchmark check
set.seed(42)
bmr = benchmark(benchmark_grid(
tasks = tsks(c("rats", "gbcs", "grace")),
learners = lrn("surv.coxph"),
resamplings = rsmp("cv", folds = 3)
))
bmr$score()$surv.cindex # > 0.7

# slow! but all all Antolini's C > 0.7 (correct implementation)
# some are equal to Harrell's C, some are not? (due 1-1 correspondence
# they all should be?)
for (i in 1:3) {
for (p in bmr$resample_results$resample_result[[i]]$predictions()) {
print(cindex(pred = p, meth = "A")) # "H" => checking Harrell's C is the same as above (YES)
}
}

0 comments on commit 95c7400

Please sign in to comment.