Support XGBoostRanker and XGBoostRegressor (#10505)

wbo4958 authored Jul 3, 2024
1 parent f79a4e9 commit e3ba9fc

Showing 25 changed files with 1,296 additions and 472 deletions.
9 changes: 9 additions & 0 deletions jvm-packages/pom.xml
@@ -434,6 +434,15 @@
           </goals>
         </execution>
       </executions>
+      <configuration>
+        <scalaVersion>${scala.version}</scalaVersion>
+        <checkMultipleScalaVersions>true</checkMultipleScalaVersions>
+        <failOnMultipleScalaVersions>false</failOnMultipleScalaVersions>
+        <recompileMode>incremental</recompileMode>
+        <args>
+          <arg>-Ywarn-unused:imports,locals,patvars,privates</arg>
+        </args>
+      </configuration>
     </plugin>
     <plugin>
       <groupId>org.apache.maven.plugins</groupId>
@@ -49,6 +49,6 @@ object BoostFromPrediction {
     testMat.setBaseMargin(testPred)

     System.out.println("result of running from initial prediction")
-    val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
+    XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
   }
 }
@@ -40,7 +40,6 @@ object CrossValidation {
     // set additional eval_metrics
     val metrics: Array[String] = null

-    val evalHist: Array[String] =
-      XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics)
+    XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics)
   }
 }
@@ -151,7 +151,7 @@ object CustomObjective {

     val round = 2
     // train a model
-    val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
+    XGBoost.train(trainMat, params.toMap, round, watches.toMap)
     XGBoost.train(trainMat, params.toMap, round, watches.toMap,
       obj = new LogRegObj, eval = new EvalError)
   }
@@ -54,6 +54,6 @@ object ExternalMemory {
     testMat.setBaseMargin(testPred)

     System.out.println("result of running from initial prediction")
-    val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
+    XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
   }
 }
@@ -51,7 +51,6 @@ object GeneralizedLinearModel {
     watches += "train" -> trainMat
     watches += "test" -> testMat

-    val round = 4
     val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
     val predicts = booster.predict(testMat)
     val eval = new CustomEval
@@ -47,7 +47,7 @@ object PredictLeafIndices {

     // predict all trees
     val leafIndex2 = booster.predictLeaf(testMat, 0)
-    for (leafs <- leafIndex) {
+    for (leafs <- leafIndex2) {
       println(java.util.Arrays.toString(leafs))
     }
   }
@@ -66,7 +66,7 @@ private[spark] def run(spark: SparkSession, inputPath: String,
     val xgbInput = vectorAssembler.transform(labelTransformed).select("features",
       "classIndex")

-    val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))
+    val Array(train, eval1, _, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))

     /**
      * setup spark.scheduler.barrier.maxConcurrentTasksCheck.interval and
@@ -28,6 +28,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.types.{DataType, FloatType, IntegerType}
 import org.apache.spark.sql.vectorized.ColumnarBatch

 import ml.dmlc.xgboost4j.java.CudfColumnBatch
@@ -65,22 +66,23 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
     val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
     val schema = dataset.schema

-    def selectCol(c: Param[String]) = {
+    def selectCol(c: Param[String], targetType: DataType = FloatType) = {
       // TODO support numeric types
       if (estimator.isDefinedNonEmpty(c)) {
-        selectedCols.append(estimator.castToFloatIfNeeded(schema, estimator.getOrDefault(c)))
+        selectedCols.append(estimator.castIfNeeded(schema, estimator.getOrDefault(c), targetType))
       }
     }

-    Seq(estimator.labelCol, estimator.weightCol, estimator.baseMarginCol).foreach(selectCol)
+    Seq(estimator.labelCol, estimator.weightCol, estimator.baseMarginCol)
+      .foreach(p => selectCol(p))
     estimator match {
-      case p: HasGroupCol => selectCol(p.groupCol)
+      case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
       case _ =>
     }

     // TODO support array/vector feature
     estimator.getFeaturesCols.foreach { name =>
-      val col = estimator.castToFloatIfNeeded(dataset.schema, name)
+      val col = estimator.castIfNeeded(dataset.schema, name)
       selectedCols.append(col)
     }
     val input = dataset.select(selectedCols: _*)
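For orientation: this mirrors the estimator-side change later in this commit, casting the group column to IntegerType while label, weight, and base margin stay FloatType. A minimal sketch of the same column preparation in plain Spark (the DataFrame and column names are made up for illustration):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{FloatType, IntegerType}

    object SelectColsSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        import spark.implicits._

        // Hypothetical ranking input: label and group arrive as double/long.
        val df = Seq((1.0, 0L, 0.5f), (0.0, 0L, 0.1f), (1.0, 1L, 0.9f))
          .toDF("label", "group", "f0")

        // Mirrors the plugin: label/features to FloatType, group to IntegerType.
        val input = df.select(
          col("label").cast(FloatType).as("label"),
          col("group").cast(IntegerType).as("group"),
          col("f0").cast(FloatType).as("f0"))
        input.printSchema()
      }
    }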
@@ -28,7 +28,7 @@ import org.apache.spark.sql.functions.{col, udf}
 import org.json4s.DefaultFormats

 import ml.dmlc.xgboost4j.scala.Booster
-import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{binaryClassificationObjs, multiClassificationObjs}
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{BINARY_CLASSIFICATION_OBJS, MULTICLASSIFICATION_OBJS}

 class XGBoostClassifier(override val uid: String,
                         private[spark] val xgboostParams: Map[String, Any])
@@ -51,7 +51,7 @@ class XGBoostClassifier(override val uid: String,
     // multiClassificationObjs
     val obj = if (isSet(objective)) {
       val tmpObj = getObjective
-      val supportedObjs = binaryClassificationObjs.toSeq ++ multiClassificationObjs.toSeq
+      val supportedObjs = BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq
       require(supportedObjs.contains(tmpObj),
         s"Wrong objective for XGBoostClassifier, supported objs: ${supportedObjs.mkString(",")}")
       Some(tmpObj)
@@ -72,7 +72,7 @@ class XGBoostClassifier(override val uid: String,

     // objective is set explicitly.
     if (obj.isDefined) {
-      if (multiClassificationObjs.contains(getObjective)) {
+      if (MULTICLASSIFICATION_OBJS.contains(getObjective)) {
         numberClasses = inferNumClasses
         setNumClass(numberClasses)
       } else {
@@ -105,18 +105,16 @@ class XGBoostClassifier(override val uid: String,

   override protected def createModel(booster: Booster, summary: XGBoostTrainingSummary):
   XGBoostClassificationModel = {
-    new XGBoostClassificationModel(uid, numberClasses, booster, Some(summary))
+    new XGBoostClassificationModel(uid, numberClasses, booster, Option(summary))
   }

 }

 object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
-  private val _uid = Identifiable.randomUID("xgbc")
-
   override def load(path: String): XGBoostClassifier = super.load(path)
 }

-class XGBoostClassificationModel(
+class XGBoostClassificationModel private[ml](
     val uid: String,
     val numClasses: Int,
     val nativeBooster: Booster,
@@ -32,8 +32,8 @@ import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, ML
 import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types._

 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.java.{Booster => JBooster}
@@ -84,7 +84,7 @@ private[spark] trait PluginMixin {
     }
   }

-  /** Visiable for testing */
+  /** Visible for testing */
   protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin

   protected def isPluginEnabled(dataset: Dataset[_]): Boolean = {
@@ -101,17 +101,19 @@ private[spark] trait XGBoostEstimator[
   protected val logger = LogFactory.getLog("XGBoostSpark")

   /**
-   * Pre-convert input double data to floats to align with XGBoost's internal float-based
-   * operations to save memory usage.
+   * Cast the field in the schema to the desired data type.
    *
-   * @param dataset the input dataset
-   * @param name which column will be casted to float if possible.
+   * @param schema     the schema of the input dataset
+   * @param name       which column will be cast if needed
+   * @param targetType the target data type
    * @return Dataset
    */
-  private[spark] def castToFloatIfNeeded(schema: StructType, name: String): Column = {
-    if (!schema(name).dataType.isInstanceOf[FloatType]) {
+  private[spark] def castIfNeeded(schema: StructType,
+                                  name: String,
+                                  targetType: DataType = FloatType): Column = {
+    if (!(schema(name).dataType == targetType)) {
       val meta = schema(name).metadata
-      col(name).as(name, meta).cast(FloatType)
+      col(name).as(name, meta).cast(targetType)
     } else {
       col(name)
     }
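For readers skimming the diff, here is a standalone restatement of the new castIfNeeded helper, the same logic lifted out of the estimator trait (the wrapping object is illustrative):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{DataType, FloatType, StructType}

    object CastIfNeededSketch {
      // Keep the column as-is when it already has the target type; otherwise
      // alias it (preserving metadata) and cast, mirroring the method above.
      def castIfNeeded(schema: StructType, name: String,
                       targetType: DataType = FloatType): Column = {
        if (schema(name).dataType == targetType) {
          col(name)
        } else {
          col(name).as(name, schema(name).metadata).cast(targetType)
        }
      }
    }

Preserving the column metadata in the alias keeps ML attribute information intact across the cast.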
@@ -180,20 +182,20 @@
     val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
     val schema = dataset.schema

-    def selectCol(c: Param[String]) = {
+    def selectCol(c: Param[String], targetType: DataType = FloatType) = {
       if (isDefinedNonEmpty(c)) {
         // Validation col should be a boolean column.
         if (c == featuresCol) {
           selectedCols.append(col($(c)))
         } else {
-          selectedCols.append(castToFloatIfNeeded(schema, $(c)))
+          selectedCols.append(castIfNeeded(schema, $(c), targetType))
         }
       }
     }

-    Seq(labelCol, featuresCol, weightCol, baseMarginCol).foreach(selectCol)
+    Seq(labelCol, featuresCol, weightCol, baseMarginCol).foreach(p => selectCol(p, FloatType))
     this match {
-      case p: HasGroupCol => selectCol(p.groupCol)
+      case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
       case _ =>
     }
     val input = repartitionIfNeeded(dataset.select(selectedCols: _*))
@@ -210,10 +212,10 @@
         val label = row.getFloat(columnIndexes.labelId)
         val weight = columnIndexes.weightId.map(row.getFloat).getOrElse(1.0f)
         val baseMargin = columnIndexes.marginId.map(row.getFloat).getOrElse(Float.NaN)
-        val group = columnIndexes.groupId.map(row.getFloat).getOrElse(-1.0f)
+        val group = columnIndexes.groupId.map(row.getInt).getOrElse(-1)
         // To make "0" meaningful, we convert sparse vector if possible to dense to create DMatrix.
         val values = features.toArray.map(_.toFloat)
-        XGBLabeledPoint(label, values.length, null, values, weight, group.toInt, baseMargin)
+        XGBLabeledPoint(label, values.length, null, values, weight, group, baseMargin)
       }
     }

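With this change the group id travels as an Int end to end. A small sketch constructing the same kind of dense XGBLabeledPoint by hand (values are made up; the null indices array denotes a dense point, matching the call above):

    import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

    object LabeledPointSketch {
      def main(args: Array[String]): Unit = {
        val values = Array(0.5f, 1.0f, 0.0f)
        // Positional fields, matching the call above:
        // label, size, indices (null = dense), values, weight, group, baseMargin
        val point = XGBLabeledPoint(1.0f, values.length, null, values, 1.0f, 0, Float.NaN)
        println(point)
      }
    }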
@@ -232,18 +234,57 @@

     val missing = getMissing

-    // transform the labeledpoint to get margins and build DMatrix
+    // Transform the labeledpoint to get margins/groups and build DMatrix
     // TODO support basemargin for multiclassification
-    // TODO, move it into JNI
+    // TODO and optimization, move it into JNI.
     def buildDMatrix(iter: Iterator[XGBLabeledPoint]) = {
-      val dmatrix = if (columnIndices.marginId.isDefined) {
-        val trainMargins = new mutable.ArrayBuilder.ofFloat
+      val dmatrix = if (columnIndices.marginId.isDefined || columnIndices.groupId.isDefined) {
+        val margins = new mutable.ArrayBuilder.ofFloat
+        val groups = new mutable.ArrayBuilder.ofInt
+        val groupWeights = new mutable.ArrayBuilder.ofFloat
+        var prevGroup = -101010
+        var prevWeight = -1.0f
+        var groupSize = 0
         val transformedIter = iter.map { labeledPoint =>
-          trainMargins += labeledPoint.baseMargin
+          if (columnIndices.marginId.isDefined) {
+            margins += labeledPoint.baseMargin
+          }
+          if (columnIndices.groupId.isDefined) {
+            if (prevGroup != labeledPoint.group) {
+              // starting a new group
+              if (prevGroup != -101010) {
+                // write the previous group
+                groups += groupSize
+                groupWeights += prevWeight
+              }
+              groupSize = 1
+              prevWeight = labeledPoint.weight
+              prevGroup = labeledPoint.group
+            } else {
+              // within the same group
+              if (prevWeight != labeledPoint.weight) {
+                throw new IllegalArgumentException("instances in the same group must be" +
+                  s" assigned the same weight (unexpected weight ${labeledPoint.weight})")
+              }
+              groupSize = groupSize + 1
+            }
+          }
           labeledPoint
         }
         val dm = new DMatrix(transformedIter, null, missing)
-        dm.setBaseMargin(trainMargins.result())
+        columnIndices.marginId.foreach(_ => dm.setBaseMargin(margins.result()))
+        if (columnIndices.groupId.isDefined) {
+          if (prevGroup != -101010) {
+            // write the last group
+            groups += groupSize
+            groupWeights += prevWeight
+          }
+          dm.setGroup(groups.result())
+          // The new DMatrix() sets one weight per instance, but ranking requires
+          // one weight per group, so the weights are reset here. This could be
+          // optimized by moving the group/base-margin handling into JNI.
+          dm.setWeight(groupWeights.result())
+        }
         dm
       } else {
         new DMatrix(iter, null, missing)
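The new branch above run-length encodes the partition-sorted group ids into group sizes, keeping exactly one weight per group. A standalone sketch of the same idea, using Option in place of the -101010 sentinel:

    object GroupEncodingSketch {
      // Collapse consecutive equal group ids into group sizes, keeping one
      // weight per group and rejecting mixed weights inside a group.
      def encodeGroups(rows: Seq[(Int, Float)]): (Array[Int], Array[Float]) = {
        val sizes = Array.newBuilder[Int]
        val weights = Array.newBuilder[Float]
        var prevGroup = Option.empty[Int]
        var prevWeight = 0.0f
        var size = 0
        for ((group, weight) <- rows) {
          if (prevGroup.contains(group)) {
            // same group: weights must agree, grow the run
            require(weight == prevWeight, "instances in the same group must share one weight")
            size += 1
          } else {
            // new group: flush the previous run, if any
            prevGroup.foreach { _ => sizes += size; weights += prevWeight }
            prevGroup = Some(group)
            prevWeight = weight
            size = 1
          }
        }
        // flush the last run
        prevGroup.foreach { _ => sizes += size; weights += prevWeight }
        (sizes.result(), weights.result())
      }

      def main(args: Array[String]): Unit = {
        val (sizes, weights) = encodeGroups(Seq((0, 1f), (0, 1f), (1, 2f), (1, 2f), (1, 2f)))
        println(sizes.mkString(","))   // 2,3
        println(weights.mkString(",")) // 1.0,2.0
      }
    }

Using Option for the previous group avoids the magic initial marker entirely.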
@@ -327,14 +368,6 @@
       SparkUtils.checkNumericType(schema, $(baseMarginCol))
     }

-    // TODO Move this to XGBoostRanker
-    //    this match {
-    //      case p: HasGroupCol =>
-    //        if (isDefined(p.groupCol) && $(p.groupCol).nonEmpty) {
-    //          SparkUtils.checkNumericType(schema, p.getGroupCol)
-    //        }
-    //    }
-
     val taskCpus = dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", 1)
     if (isDefined(nthread)) {
       require(getNthread <= taskCpus,
@@ -519,6 +552,14 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
   }

   override def write: MLWriter = new XGBoostModelWriter[XGBoostModel[_]](this)
+
+  protected def predictSingleInstance(features: Vector): Array[Float] = {
+    if (nativeBooster == null) {
+      throw new IllegalArgumentException("The model has not been trained")
+    }
+    val dm = new DMatrix(Iterator(features.asXGB), null, getMissing)
+    nativeBooster.predict(data = dm)(0)
+  }
 }

 /**
@@ -560,3 +601,22 @@ private[spark] abstract class XGBoostModelReader[M <: XGBoostModel[M]] extends M
     }
   }
 }
+
+// Trait for the Ranker and Regressor models
+private[spark] trait RankerRegressorBaseModel[M <: XGBoostModel[M]] extends XGBoostModel[M] {
+
+  override protected[spark] def postTransform(dataset: Dataset[_],
+      pred: PredictedColumns): Dataset[_] = {
+    var output = super.postTransform(dataset, pred)
+    if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
+      val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
+        originalPrediction(0).toDouble
+      }
+      output = output
+        .withColumn($(predictionCol), predictUDF(col(TMP_TRANSFORMED_COL)))
+        .drop(TMP_TRANSFORMED_COL)
+    }
+    output
+  }
+
+}
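Taken together, these pieces back the two estimators named in the commit title. A sketch of the expected Spark ML usage, with the caveat that the setter names are inferred from parameters appearing in this diff (labelCol, featuresCol, groupCol) and the data is invented:

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.sql.SparkSession

    import ml.dmlc.xgboost4j.scala.spark.{XGBoostRanker, XGBoostRegressor}

    object RankerRegressorSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        import spark.implicits._

        val df = Seq((1.0f, 0, 0.5, 1.2), (0.0f, 0, 0.1, 0.3), (2.0f, 1, 0.9, 2.1))
          .toDF("label", "group", "f0", "f1")
        val assembled = new VectorAssembler()
          .setInputCols(Array("f0", "f1"))
          .setOutputCol("features")
          .transform(df)

        // Regression: the usual label/features pair.
        val regModel = new XGBoostRegressor()
          .setLabelCol("label")
          .setFeaturesCol("features")
          .fit(assembled)
        regModel.transform(assembled).select("prediction").show()

        // Ranking: additionally point at the integer group column; rows within
        // one group must share the same weight, per the check in buildDMatrix.
        val rankModel = new XGBoostRanker()
          .setLabelCol("label")
          .setFeaturesCol("features")
          .setGroupCol("group")
          .fit(assembled)
        rankModel.transform(assembled).show()
      }
    }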