Skip to content

Commit

Permalink
Stick to linter
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Dec 8, 2023
1 parent aadeb7e commit a69ed89
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ convert.labels <- function(labels, objective_name) {

# Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {

# cannot stratify if label is NULL
if (isTRUE(stratified) && is.null(label)) {
warning("Will use unstratified splitting (no labels provided)")
Expand Down
4 changes: 2 additions & 2 deletions R-package/R/xgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing

check.custom.obj()
check.custom.eval()

# AFT uses 'label_lower/upper_bound' instead of 'label'
if (is.character(params$objective) && params$objective == 'survival:aft') {
if (!inherits(data, 'xgb.DMatrix')) {
stop("Objective 'survival:aft' requires the data to be an 'xgb.DMatrix'.")
}
if (is.null(getinfo(data, name = 'label_lower_bound')) ||
if (is.null(getinfo(data, name = 'label_lower_bound')) ||
is.null(getinfo(data, name = 'label_upper_bound'))) {
stop("Objective 'survival:aft' requires 'label_lower_bound' and 'label_upper_bound.")
}
Expand Down
18 changes: 9 additions & 9 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -391,29 +391,29 @@ test_that("xgb.cv works with stratified folds", {
test_that("xgb.cv works for AFT", {
X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE) # 4x2 matrix
dtrain <- xgb.DMatrix(X, nthread = n_threads)

setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4))
setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5))

params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L)

# data must be xgb.DMatrix in aft case
expect_error(
xgb.cv(
params = params,
data = X,
nround = 5L,
nfold = 4L,
params = params,
data = X,
nround = 5L,
nfold = 4L,
nthread = n_threads,
label = c(2, 3, 0, 4)
)
)

# automatic stratified splitting is turned off
expect_warning(
xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, nthread = n_threads)
)

# this works without any issue
expect_no_warning(
xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, stratified = FALSE)
Expand Down

0 comments on commit a69ed89

Please sign in to comment.