From 5204ab84c968e40a540f2adb45d8a82b333b9836 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 10 Aug 2023 10:33:39 -0500 Subject: [PATCH] [R] remove default values in internal booster manipulation functions --- R-package/R/callbacks.R | 4 ++-- R-package/R/xgb.Booster.R | 14 +++++++++----- R-package/R/xgb.cv.R | 7 ++++++- R-package/R/xgb.load.R | 11 ++++++++--- R-package/R/xgb.train.R | 9 +++++++-- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/R-package/R/callbacks.R b/R-package/R/callbacks.R index d2ee59476e6e..7265967b2ad9 100644 --- a/R-package/R/callbacks.R +++ b/R-package/R/callbacks.R @@ -511,7 +511,7 @@ cb.cv.predict <- function(save_models = FALSE) { if (save_models) { env$basket$models <- lapply(env$bst_folds, function(fd) { xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1 - xgb.Booster.complete(xgb.handleToBooster(fd$bst), saveraw = TRUE) + xgb.Booster.complete(xgb.handleToBooster(handle = fd$bst, raw = NULL), saveraw = TRUE) }) } } @@ -659,7 +659,7 @@ cb.gblinear.history <- function(sparse = FALSE) { } else { # xgb.cv: cf <- vector("list", length(env$bst_folds)) for (i in seq_along(env$bst_folds)) { - dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst)) + dmp <- xgb.dump(xgb.handleToBooster(handle = env$bst_folds[[i]]$bst, raw = NULL)) cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE)) if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector") } diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 6a53577e990f..5ffbbc31c869 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -1,7 +1,6 @@ # Construct an internal xgboost Booster and return a handle to it. # internal utility function -xgb.Booster.handle <- function(params = list(), cachelist = list(), - modelfile = NULL, handle = NULL) { +xgb.Booster.handle <- function(params, cachelist, modelfile, handle) { if (typeof(cachelist) != "list" || !all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) { stop("cachelist must be a list of xgb.DMatrix objects") @@ -44,7 +43,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(), # Convert xgb.Booster.handle to xgb.Booster # internal utility function -xgb.handleToBooster <- function(handle, raw = NULL) { +xgb.handleToBooster <- function(handle, raw) { bst <- list(handle = handle, raw = raw) class(bst) <- "xgb.Booster" return(bst) @@ -129,7 +128,12 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) { stop("argument type must be xgb.Booster") if (is.null.handle(object$handle)) { - object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle) + object$handle <- xgb.Booster.handle( + params = list(), + cachelist = list(), + modelfile = object$raw, + handle = object$handle + ) } else { if (is.null(object$raw) && saveraw) { object$raw <- xgb.serialize(object$handle) @@ -475,7 +479,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA #' @export predict.xgb.Booster.handle <- function(object, ...) { - bst <- xgb.handleToBooster(object) + bst <- xgb.handleToBooster(handle = object, raw = NULL) ret <- predict(bst, ...) return(ret) diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 788638921fae..24c1b3f3cb90 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -202,7 +202,12 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing dtrain <- slice(dall, unlist(folds[-k])) else dtrain <- slice(dall, train_folds[[k]]) - handle <- xgb.Booster.handle(params, list(dtrain, dtest)) + handle <- xgb.Booster.handle( + params = params, + cachelist = list(dtrain, dtest), + modelfile = NULL, + handle = NULL + ) list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]]) }) rm(dall) diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index d98041908eb4..cfbf0b2d82bf 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -35,7 +35,12 @@ xgb.load <- function(modelfile) { if (is.null(modelfile)) stop("xgb.load: modelfile cannot be NULL") - handle <- xgb.Booster.handle(modelfile = modelfile) + handle <- xgb.Booster.handle( + params = list(), + cachelist = list(), + modelfile = modelfile, + handle = NULL + ) # re-use modelfile if it is raw so we do not need to serialize if (typeof(modelfile) == "raw") { warning( @@ -45,9 +50,9 @@ xgb.load <- function(modelfile) { " `xgb.unserialize` instead. " ) ) - bst <- xgb.handleToBooster(handle, modelfile) + bst <- xgb.handleToBooster(handle = handle, raw = modelfile) } else { - bst <- xgb.handleToBooster(handle, NULL) + bst <- xgb.handleToBooster(handle = handle, raw = NULL) } bst <- xgb.Booster.complete(bst, saveraw = TRUE) return(bst) diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 729475945c4d..7fe64ab34bdf 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -363,8 +363,13 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(), is_update <- NVL(params[['process_type']], '.') == 'update' # Construct a booster (either a new one or load from xgb_model) - handle <- xgb.Booster.handle(params, append(watchlist, dtrain), xgb_model) - bst <- xgb.handleToBooster(handle) + handle <- xgb.Booster.handle( + params = params, + cachelist = append(watchlist, dtrain), + modelfile = xgb_model, + handle = NULL + ) + bst <- xgb.handleToBooster(handle = handle, raw = NULL) # extract parameters that can affect the relationship b/w #trees and #iterations num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1)