Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 11, 2024
1 parent e9ede8b commit 7b30265
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param._

import scala.collection.immutable.HashSet

/** Holds the sets of legal values for Dart booster string parameters. */
private[spark] object DartBoosterParams {
  /** Accepted values for the `normalizeType` parameter. */
  val supportedNormalizeType: HashSet[String] = HashSet("tree", "forest")
}

/**
* Dart booster parameters, more details can be found at
* https://xgboost.readthedocs.io/en/stable/parameter.html#additional-parameters-for-dart-booster-booster-dart
*/
private[spark] trait DartBoosterParams extends Params {

final val sampleType = new Param[String](this, "sampleType", "Type of sampling algorithm, " +
Expand All @@ -27,7 +37,7 @@ private[spark] trait DartBoosterParams extends Params {

final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
" algorithm, options: {'tree', 'forest'}",
(value: String) => BoosterParams.supportedNormalizeType.contains(value))
(value: String) => DartBoosterParams.supportedNormalizeType.contains(value))

final def getNormalizeType: String = $(normalizeType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,6 @@

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param.{IntParam, Params}
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}

private[spark] trait InferenceParams extends Params {

  /**
   * Number of rows processed per iteration when running inference.
   *
   * NOTE(review): the Spark param key is "batchSize" while the field is named
   * `inferBatchSize` — intentional or not, the key is part of the public
   * contract, so it is preserved here.
   */
  final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")

  /** @group getParam */
  final def getInferBatchSize: Int = $(inferBatchSize)

  // Default: 32Ki rows per inference batch.
  setDefault(inferBatchSize -> (32 << 10))
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param._

import scala.collection.immutable.HashSet


/**
* TreeBoosterParams defines the XGBoost TreeBooster parameters for Spark
*
Expand Down Expand Up @@ -228,21 +231,15 @@ private[spark] trait TreeBoosterParams extends Params {
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0,
subsample -> 1, samplingMethod -> "uniform", colsampleBytree -> 1, colsampleBylevel -> 1,
colsampleBynode -> 1, lambda -> 1, alpha -> 0, treeMethod -> "auto", scalePosWeight -> 1,
processType->"default", growPolicy->"depthwise", maxLeaves->0, maxBins->256, numParallelTree->1,
maxCachedHistNode->65536)


/**
* The device for running XGBoost algorithms, options: cpu, cuda
*/
final val device = new Param[String](
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda",
(value: String) => BoosterParams.supportedDevices.contains(value)
)

final def getDevice: String = $(device)
processType -> "default", growPolicy -> "depthwise", maxLeaves -> 0, maxBins -> 256,
numParallelTree -> 1, maxCachedHistNode -> 65536)

}

/** Holds the sets of legal values for booster string parameters. */
private[spark] object BoosterParams {

  /** Tree-construction algorithms accepted by the `treeMethod` parameter. */
  val supportedTreeMethods: HashSet[String] = HashSet("auto", "exact", "approx", "hist")

  /** Tree updater plugin names accepted by the `updater` parameter. */
  val supportedUpdaters: HashSet[String] = HashSet("grow_colmaker", "grow_histmaker",
    "grow_quantile_histmaker", "grow_gpu_hist", "grow_gpu_approx", "sync", "refresh",
    "prune")
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._

trait HasInferenceSizeParams extends Params {

  /**
   * Number of rows grouped into one batch during inference; must be >= 1.
   */
  final val inferBatchSize = new IntParam(this, "inferBatchSize",
    "batch size in rows to be grouped for inference", ParamValidators.gtEq(1))

  /** @group getParam */
  final def getInferBatchSize: Int = $(inferBatchSize)
}

trait HasLeafPredictionCol extends Params {
/**
* Param for leaf prediction column name.
Expand Down Expand Up @@ -58,7 +70,6 @@ trait HasBaseMarginCol extends Params {
/** @group getParam */
final def getBaseMarginCol: String = $(baseMarginCol)

def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
}

trait HasGroupCol extends Params {
Expand All @@ -67,10 +78,8 @@ trait HasGroupCol extends Params {

/** @group getParam */
final def getGroupCol: String = $(groupCol)

}


/**
* Trait for shared param featuresCols.
*/
Expand Down Expand Up @@ -101,7 +110,8 @@ trait HasFeaturesCols extends Params {
*/
private[spark] trait SparkParams[T <: Params] extends Params
with HasFeaturesCol with HasLabelCol with HasBaseMarginCol with HasWeightCol
with HasPredictionCol with HasLeafPredictionCol with HasContribPredictionCol {
with HasPredictionCol with HasLeafPredictionCol with HasContribPredictionCol
with HasInferenceSizeParams {

final val numWorkers = new IntParam(this, "numWorkers", "Number of workers used to train xgboost",
ParamValidators.gtEq(1))
Expand All @@ -111,21 +121,29 @@ private[spark] trait SparkParams[T <: Params] extends Params
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
ParamValidators.gtEq(1))

setDefault(numRound -> 100, numWorkers -> 1)
setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10))

final def getNumWorkers: Int = $(numWorkers)

def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T]

def setNumRound(value: Int): T = set(numRound, value).asInstanceOf[T]

def setFeaturesCol(value: String): T = set(featuresCol, value).asInstanceOf[T]

def setLabelCol(value: String): T = set(labelCol, value).asInstanceOf[T]

def setBaseMarginCol(value: String): T = set(baseMarginCol, value).asInstanceOf[T]

def setWeightCol(value: String): T = set(weightCol, value).asInstanceOf[T]

def setPredictionCol(value: String): T = set(predictionCol, value).asInstanceOf[T]

def setLeafPredictionCol(value: String): T = set(leafPredictionCol, value).asInstanceOf[T]

def setContribPredictionCol(value: String): T = set(contribPredictionCol, value).asInstanceOf[T]

def setInferBatchSize(value: Int): T = set(inferBatchSize, value).asInstanceOf[T]
}

/**
Expand Down

0 comments on commit 7b30265

Please sign in to comment.