Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jan 7, 2025
1 parent c1010b3 commit c7018fe
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,17 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
val nthread = estimator.getNthread
val missing = estimator.getMissing

val useExtMem = estimator.getUseExternalMemory
val extMemPath = if (useExtMem) {
Some(dataset.sparkSession.conf.get("spark.local.dir", "/tmp"))
} else None

/** build QuantileDMatrix on the executor side */
def buildQuantileDMatrix(input: Iterator[Table],
ref: Option[QuantileDMatrix] = None): QuantileDMatrix = {
val cachePath = Option("/tmp/")

val (iterator, useExtMem) = cachePath match {
val (iterator, useExtMem) = extMemPath match {
case Some(_) =>
(new ExternalMemoryIterator(input, indices, cachePath), true)
(new ExternalMemoryIterator(input, indices, extMemPath), true)
case None =>
(input.map { table =>
withResource(new GpuColumnBatch(table)) { batch =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
.setFeaturesCol(features)
.setLabelCol(label)
.setDevice("cuda")
.setUseExternalMemory(true)
val model = estimator.fit(df)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,24 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe

final def getFeatureTypes: Array[String] = $(featureTypes)

final val useExternalMemory = new BooleanParam(this, "useExternalMemory", "Whether to use " +
"the external memory or not when building QuantileDMatrix. Please note that " +
"useExternalMemory is useful only when `device` is set to `cuda` or `gpu`. When " +
"useExternalMemory is enabled, the directory specified by spark.local.dir if set will be " +
"used to cache the temporary files, if spark.local.dir is not set, the /tmp directory " +
"will be used.")

final def getUseExternalMemory: Boolean = $(useExternalMemory)

setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10),
numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
featuresCols -> Array.empty, customObj -> null, customEval -> null,
featureNames -> Array.empty, featureTypes -> Array.empty)
featureNames -> Array.empty, featureTypes -> Array.empty, useExternalMemory -> false)

addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol,
labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,
forceRepartition, featuresCols, customEval, customObj, featureTypes, featureNames)
forceRepartition, featuresCols, customEval, customObj, featureTypes, featureNames,
useExternalMemory)

final def getNumWorkers: Int = $(numWorkers)

Expand Down Expand Up @@ -224,6 +234,8 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe

def setFeatureTypes(value: Array[String]): T = set(featureTypes, value).asInstanceOf[T]

def setUseExternalMemory(value: Boolean): T = set(useExternalMemory, value).asInstanceOf[T]

protected[spark] def featureIsArrayType(schema: StructType): Boolean =
schema(getFeaturesCol).dataType.isInstanceOf[ArrayType]

Expand Down

0 comments on commit c7018fe

Please sign in to comment.