diff --git a/jvm-packages/scalastyle-config.xml b/jvm-packages/scalastyle-config.xml index 8463afe9b049..b9b576c6cbcb 100644 --- a/jvm-packages/scalastyle-config.xml +++ b/jvm-packages/scalastyle-config.xml @@ -210,7 +210,7 @@ This file is divided into 3 sections: java,scala,3rdParty,dmlc javax?\..* scala\..* - (?!ml\.dmlc\.xgboost4j\.).* + (?!ml\.dmlc\.xgboost4j).* ml.dmlc.xgboost4j.* diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/resources/META-INF/services/ml.dmlc.xgboost4j.scala.spark.XGBoostPlugin b/jvm-packages/xgboost4j-spark-gpu/src/main/resources/META-INF/services/ml.dmlc.xgboost4j.scala.spark.XGBoostPlugin index 8427404c5ae6..11a1de8bf147 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/resources/META-INF/services/ml.dmlc.xgboost4j.scala.spark.XGBoostPlugin +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/resources/META-INF/services/ml.dmlc.xgboost4j.scala.spark.XGBoostPlugin @@ -1 +1 @@ -ml.dmlc.xgboost4j.scala.spark.GPUXGBoostPlugin +ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrix.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrix.scala index 67162cfb342d..93a773829f43 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrix.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrix.scala @@ -21,7 +21,7 @@ import _root_.scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, QuantileDMatrix => JQuantileDMatrix, XGBoostError} class QuantileDMatrix private[scala]( - private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) { + private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) { /** * Create QuantileDMatrix from iterator based on the cuda array interface diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GPUXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GPUXGBoostPlugin.scala deleted file mode 100644 index cd7bc965c9c7..000000000000 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GPUXGBoostPlugin.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - Copyright (c) 2024 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package ml.dmlc.xgboost4j.scala.spark - -import scala.collection.mutable.ArrayBuffer -import scala.jdk.CollectionConverters.seqAsJavaListConverter - -import com.nvidia.spark.rapids.ColumnarRdd -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, Dataset} - -import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch} -import ml.dmlc.xgboost4j.scala.QuantileDMatrix - -private[spark] case class ColumnIndices( - labelId: Int, - featuresId: Seq[Int], - weightId: Option[Int], - marginId: Option[Int], - groupId: Option[Int]) - -class GPUXGBoostPlugin extends XGBoostPlugin { - - /** - * Whether the plugin is enabled or not, if not enabled, fallback - * to the regular CPU pipeline - * - * @param dataset the input dataset - * @return Boolean - */ - override def isEnabled(dataset: Dataset[_]): Boolean = { - val conf = dataset.sparkSession.conf - val hasRapidsPlugin = conf.get("spark.sql.extensions", "").split(",").contains( - "com.nvidia.spark.rapids.SQLExecPlugin") - val rapidsEnabled = conf.get("spark.rapids.sql.enabled", "false").toBoolean - hasRapidsPlugin && rapidsEnabled - } - - /** - * Convert Dataset to RDD[Watches] which will be fed into XGBoost - * - * @param estimator which estimator to be handled. - * @param dataset to be converted. - * @return RDD[Watches] - */ - override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]]( - estimator: XGBoostEstimator[T, M], - dataset: Dataset[_]): RDD[Watches] = { - println("buildRddWatches ---") - - // TODO, check if the feature in featuresCols is numeric. - - val features = estimator.getFeaturesCols - val maxBin = estimator.getMaxBins - val nthread = estimator.getNthread - // TODO cast features to float if possible - - val label = estimator.getLabelCol - val missing = Float.NaN - - val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty - (features.toSeq ++ Seq(estimator.getLabelCol)).foreach {name => - val col = estimator.castToFloatIfNeeded(dataset.schema, name) - selectedCols.append(col) - } - var input = dataset.select(selectedCols: _*) - input = input.repartition(estimator.getNumWorkers) - - val schema = input.schema - val indices = ColumnIndices( - schema.fieldIndex(label), - features.map(schema.fieldIndex), - None, None, None - ) - - ColumnarRdd(input).mapPartitions { iter => - val colBatchIter = iter.map { table => - withResource(new GpuColumnBatch(table, null)) { batch => - new CudfColumnBatch( - batch.slice(indices.featuresId.map(Integer.valueOf).asJava), - batch.slice(indices.labelId), - batch.slice(indices.weightId.getOrElse(-1)), - batch.slice(indices.marginId.getOrElse(-1))); - } - } - - val dm = new QuantileDMatrix(colBatchIter, missing, maxBin, nthread) - Iterator.single(new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)) - } - } - - /** Executes the provided code block and then closes the resource */ - def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = { - try { - block(r) - } finally { - r.close() - } - } - -} diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala new file mode 100644 index 000000000000..5cef49799fc5 --- /dev/null +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala @@ -0,0 +1,152 @@ +/* + Copyright (c) 2024 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in 
compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters.seqAsJavaListConverter + +import ai.rapids.cudf.Table +import com.nvidia.spark.rapids.ColumnarRdd +import org.apache.spark.ml.param.Param +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset} + +import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch} +import ml.dmlc.xgboost4j.scala.QuantileDMatrix +import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol + +/** + * GpuXGBoostPlugin is the XGBoost plugin which leverages spark-rapids + * to accelerate XGBoost from ETL to training. + */ +class GpuXGBoostPlugin extends XGBoostPlugin { + + /** + * Whether the plugin is enabled. If not enabled, fall back + * to the regular CPU pipeline. + * + * @param dataset the input dataset + * @return Boolean + */ + override def isEnabled(dataset: Dataset[_]): Boolean = { + val conf = dataset.sparkSession.conf + val hasRapidsPlugin = conf.get("spark.sql.extensions", "").split(",").contains( + "com.nvidia.spark.rapids.SQLExecPlugin") + val rapidsEnabled = conf.get("spark.rapids.sql.enabled", "false").toBoolean + hasRapidsPlugin && rapidsEnabled + } + + // TODO, support numeric type + private def preprocess[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]]( + estimator: XGBoostEstimator[T, M], dataset: Dataset[_]): Dataset[_] = { + + // Columns to be selected for XGBoost training + val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty + val schema = dataset.schema + + def selectCol(c: Param[String]) = { + // TODO support numeric types + if (estimator.isDefinedNonEmpty(c)) { + selectedCols.append(estimator.castToFloatIfNeeded(schema, estimator.getOrDefault(c))) + } + } + + Seq(estimator.labelCol, estimator.weightCol, estimator.baseMarginCol).foreach(selectCol) + estimator match { + case p: HasGroupCol => selectCol(p.groupCol) + case _ => + } + + // TODO support array/vector feature + estimator.getFeaturesCols.foreach { name => + val col = estimator.castToFloatIfNeeded(dataset.schema, name) + selectedCols.append(col) + } + val input = dataset.select(selectedCols: _*) + estimator.repartitionIfNeeded(input) + } + + private def validate[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]]( + estimator: XGBoostEstimator[T, M], + dataset: Dataset[_]): Unit = { + require(estimator.getTreeMethod == "gpu_hist" || estimator.getDevice != "cpu", + "Using Spark-Rapids to accelerate XGBoost requires setting device=cuda") + } + + /** + * Convert Dataset to RDD[Watches] which will be fed into XGBoost + * + * @param estimator the estimator to be handled. + * @param dataset to be converted. 
+ * @return RDD[Watches] + */ + override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]]( + estimator: XGBoostEstimator[T, M], + dataset: Dataset[_]): RDD[Watches] = { + + validate(estimator, dataset) + + val train = preprocess(estimator, dataset) + val schema = train.schema + + val indices = estimator.buildColumnIndices(schema) + + val maxBin = estimator.getMaxBins + val nthread = estimator.getNthread + val missing = estimator.getMissing + + /** Build the QuantileDMatrix on the executor side */ + def buildQuantileDMatrix(iter: Iterator[Table]): QuantileDMatrix = { + val colBatchIter = iter.map { table => + withResource(new GpuColumnBatch(table, null)) { batch => + new CudfColumnBatch( + batch.slice(indices.featureIds.get.map(Integer.valueOf).asJava), + batch.slice(indices.labelId), + batch.slice(indices.weightId.getOrElse(-1)), + batch.slice(indices.marginId.getOrElse(-1))) + } + } + new QuantileDMatrix(colBatchIter, missing, maxBin, nthread) + } + + estimator.getEvalDataset().map { evalDs => + val evalProcessed = preprocess(estimator, evalDs) + ColumnarRdd(train.toDF()).zipPartitions(ColumnarRdd(evalProcessed.toDF())) { + (trainIter, evalIter) => + val trainDM = buildQuantileDMatrix(trainIter) + val evalDM = buildQuantileDMatrix(evalIter) + Iterator.single(new Watches(Array(trainDM, evalDM), + Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)) + } + }.getOrElse( + ColumnarRdd(train.toDF()).mapPartitions { iter => + val dm = buildQuantileDMatrix(iter) + Iterator.single(new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)) + } + ) + } + + /** Executes the provided code block and then closes the resource */ + def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = { + try { + block(r) + } finally { + r.close() + } + } + +} diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala index 7f87e4319a8c..215fc81f3303 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala @@ -41,10 +41,10 @@ class XXXXXSuite extends AnyFunSuite with GpuTestSuite { var Array(trainDf, validationDf) = df.randomSplit(Array(0.8, 0.2), seed = 1) - trainDf = trainDf.withColumn("validation", lit(false)) - validationDf = validationDf.withColumn("validationDf", lit(true)) +// trainDf = trainDf.withColumn("validation", lit(false)) +// validationDf = validationDf.withColumn("validationDf", lit(true)) - df = trainDf.union(validationDf) +// df = trainDf.union(validationDf) // // // Assemble the feature columns into a single vector column // val assembler = new VectorAssembler() @@ -63,7 +63,9 @@ class XXXXXSuite extends AnyFunSuite with GpuTestSuite { // .setBaseMarginCol("base_margin") .setFeaturesCol(features) .setLabelCol(labelCol) - .setValidationIndicatorCol("validation") + .setDevice("cuda") + .setEvalDataset(validationDf) +// .setValidationIndicatorCol("validation") // .setPredictionCol("") .setRawPredictionCol("") .setProbabilityCol("xxxx") @@ -76,16 +78,16 @@ class XXXXXSuite extends AnyFunSuite with GpuTestSuite { println(loadedEst.getNumRound) println(loadedEst.getMaxDepth) - val model = loadedEst.fit(df) + val model = est.fit(trainDf) println("-----------------------") println(model.getNumRound) println(model.getMaxDepth) - model.write.overwrite().save("/tmp/model/") - val loadedModel = 
XGBoostClassificationModel.load("/tmp/model") - println(loadedModel.getNumRound) - println(loadedModel.getMaxDepth) - model.transform(df).drop(features: _*).show(150, false) +// model.write.overwrite().save("/tmp/model/") +// val loadedModel = XGBoostClassificationModel.load("/tmp/model") +// println(loadedModel.getNumRound) +// println(loadedModel.getMaxDepth) +// model.transform(df).drop(features: _*).show(150, false) } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index a48dc987aeb8..1afff94b6df8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -28,15 +28,15 @@ import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker, XGBoostError} import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} private[spark] case class RuntimeParams( - numWorkers: Int, - numRounds: Int, - obj: ObjectiveTrait, - eval: EvalTrait, - trackerConf: TrackerConf, - earlyStoppingRounds: Int, - device: String, - isLocal: Boolean, - runOnGpu: Boolean) + numWorkers: Int, + numRounds: Int, + obj: ObjectiveTrait, + eval: EvalTrait, + trackerConf: TrackerConf, + earlyStoppingRounds: Int, + device: String, + isLocal: Boolean, + runOnGpu: Boolean) /** * A trait to manage stage-level scheduling @@ -195,7 +195,11 @@ private[spark] object XGBoost extends StageLevelScheduling { rabitEnv.put("DMLC_TASK_ID", partitionId.toString) try { - Communicator.init(rabitEnv) + try { + Communicator.init(rabitEnv) + } catch { + case e: Throwable => logger.error(e) + } val numEarlyStoppingRounds = runtimeParams.earlyStoppingRounds val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](runtimeParams.numRounds)) @@ -282,7 +286,11 @@ private[spark] object XGBoost extends StageLevelScheduling { logger.error("XGBoost job was aborted due to ", t) throw t } finally { - tracker.stop() + try { + tracker.stop() + } catch { + case t: Throwable => logger.error(t) + } } } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 26bd98966df6..04bbb8fc8df5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -20,7 +20,7 @@ import scala.collection.mutable import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader} import org.apache.spark.ml.xgboost.SparkUtils import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions.{col, udf} @@ -86,16 +86,17 @@ class XGBoostClassifier(override val uid: String, } object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] { - private val uid = Identifiable.randomUID("xgbc") + private val _uid = Identifiable.randomUID("xgbc") + override def load(path: String): XGBoostClassifier = super.load(path) } // TODO add num classes class XGBoostClassificationModel( - uid: String, - model: Booster, - trainingSummary: Option[XGBoostTrainingSummary] = None - 
) + uid: String, + model: Booster, + trainingSummary: Option[XGBoostTrainingSummary] = None +) extends XGBoostModel[XGBoostClassificationModel](uid, model, trainingSummary) with ClassificationParams[XGBoostClassificationModel] { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index 19c54c2e4a98..cbca99159dba 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -25,7 +25,7 @@ import org.apache.commons.logging.LogFactory import org.apache.hadoop.fs.Path import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.Vector -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter} import org.apache.spark.ml.xgboost.SparkUtils import org.apache.spark.rdd.RDD @@ -40,20 +40,34 @@ import ml.dmlc.xgboost4j.scala.spark.params._ /** - * Hold the column indexes used to get the column index + * Holds the column indices of the input columns */ -private case class ColumnIndexes(label: String, - features: String, - weight: Option[String] = None, - baseMargin: Option[String] = None, - group: Option[String] = None, - valiation: Option[String] = None) +private[spark] case class ColumnIndices( + labelId: Int, + featureId: Option[Int], // used when the feature type is VectorUDT or Array + featureIds: Option[Seq[Int]], // used when the features are separate columns + weightId: Option[Int], + marginId: Option[Int], + groupId: Option[Int]) + +private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]] { + + private var dataset: Option[Dataset[_]] = None + + def setEvalDataset(ds: Dataset[_]): T = { + this.dataset = Some(ds) + this.asInstanceOf[T] + } + + def getEvalDataset(): Option[Dataset[_]] = { + this.dataset + } +} private[spark] abstract class XGBoostEstimator[ - Learner <: XGBoostEstimator[Learner, M], - M <: XGBoostModel[M] -] extends Estimator[M] with XGBoostParams[Learner] with SparkParams[Learner] - with ParamMapConversion with DefaultParamsWritable { + Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M] + with XGBoostParams[Learner] with SparkParams[Learner] + with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable { protected val logger = LogFactory.getLog("XGBoostSpark") @@ -64,9 +78,9 @@ private[spark] abstract class XGBoostEstimator[ val serviceLoader = ServiceLoader.load(classOf[XGBoostPlugin], classLoader) - // For now, we only trust GPUXGBoostPlugin. + // For now, we only trust GpuXGBoostPlugin. serviceLoader.asScala.filter(x => x.getClass.getName.equals( - "ml.dmlc.xgboost4j.scala.spark.GPUXGBoostPlugin")).toList match { + "ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin")).toList match { case Nil => None case head :: Nil => Some(head) @@ -96,163 +110,145 @@ } /** - * Preprocess the dataset to meet the xgboost input requirement + * Repartition the dataset to numWorkers partitions if needed. 
* - * @param dataset - * @return + * @param dataset to be repartitioned + * @return the repartitioned dataset */ - private def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndexes) = { - // Columns to be selected for XGBoost - val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty - val schema = dataset.schema - - // TODO, support columnar and array. - selectedCols.append(castToFloatIfNeeded(schema, getLabelCol)) - selectedCols.append(col(getFeaturesCol)) - - val weightName = if (isDefined(weightCol) && getWeightCol.nonEmpty) { - selectedCols.append(castToFloatIfNeeded(schema, getWeightCol)) - Some(getWeightCol) + private[spark] def repartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = { + val numPartitions = dataset.rdd.getNumPartitions + if (getForceRepartition || getNumWorkers != numPartitions) { + dataset.repartition(getNumWorkers) } else { - None + dataset } + } - val baseMarginName = if (isDefined(baseMarginCol) && getBaseMarginCol.nonEmpty) { - selectedCols.append(castToFloatIfNeeded(schema, getBaseMarginCol)) - Some(getBaseMarginCol) - } else { - None - } + /** + * Build the column indices. + */ + private[spark] def buildColumnIndices(schema: StructType): ColumnIndices = { + // Get feature id(s) + val (featureIds: Option[Seq[Int]], featureId: Option[Int]) = + if (getFeaturesCols.length != 0) { + (Some(getFeaturesCols.map(schema.fieldIndex).toSeq), None) + } else { + (None, Some(schema.fieldIndex(getFeaturesCol))) + } - // TODO, check the validation col - val validationName = if (isDefined(validationIndicatorCol) && - getValidationIndicatorCol.nonEmpty) { - selectedCols.append(col(getValidationIndicatorCol)) - Some(getValidationIndicatorCol) - } else { - None + // Function to get the column id for the given parameter + def columnId(param: Param[String]): Option[Int] = { + if (isDefined(param) && $(param).nonEmpty) { + Some(schema.fieldIndex($(param))) + } else { + None + } } - var groupName: Option[String] = None - this match { - case p: HasGroupCol => - // Cast group col to IntegerType if necessary - if (isDefined(p.groupCol) && $(p.groupCol).nonEmpty) { - selectedCols.append(castToFloatIfNeeded(schema, p.getGroupCol)) - groupName = Some(p.getGroupCol) - } - case _ => + // Special handling for the group column + val groupId: Option[Int] = this match { + case p: HasGroupCol => columnId(p.groupCol) + case _ => None } - var input = dataset.select(selectedCols: _*) + ColumnIndices( + labelId = columnId(labelCol).get, + featureId = featureId, + featureIds = featureIds, + columnId(weightCol), + columnId(baseMarginCol), + groupId) + } - // TODO, - // 1. add a parameter to force repartition, - // 2. follow xgboost pyspark way check if repartition is needed. - val numWorkers = getNumWorkers - val numPartitions = dataset.rdd.getNumPartitions - input = if (numWorkers == numPartitions) { - input - } else { - input.repartition(numWorkers) - } - val columnIndexes = ColumnIndexes( - getLabelCol, - getFeaturesCol, - weight = weightName, - baseMargin = baseMarginName, - group = groupName, - valiation = validationName) - (input, columnIndexes) + private[spark] def isDefinedNonEmpty(param: Param[String]): Boolean = { + isDefined(param) && $(param).nonEmpty + } /** - * Convert the dataframe to RDD + * Preprocess the dataset to meet the xgboost input requirement * * @param dataset
- * @return RDD + */ - def toRdd(dataset: Dataset[_], columnIndexes: ColumnIndexes): RDD[Watches] = { + private def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = { + + // Columns to be selected for XGBoost training + val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty + val schema = dataset.schema + + def selectCol(c: Param[String]) = { + if (isDefinedNonEmpty(c)) { + // The features column is selected as-is; other columns are cast to float if needed. + if (c == featuresCol) { + selectedCols.append(col($(c))) + } else { + selectedCols.append(castToFloatIfNeeded(schema, $(c))) + } + } + } - // 1. to XGBLabeledPoint - val labeledPointRDD = dataset.rdd.map { + Seq(labelCol, featuresCol, weightCol, baseMarginCol).foreach(selectCol) + this match { + case p: HasGroupCol => selectCol(p.groupCol) + case _ => + } + val input = repartitionIfNeeded(dataset.select(selectedCols: _*)) + + val columnIndices = buildColumnIndices(input.schema) + (input, columnIndices) + } + + private def toXGBLabeledPoint(dataset: Dataset[_], + columnIndices: ColumnIndices): RDD[XGBLabeledPoint] = { + dataset.rdd.map { case row: Row => - val label = row.getFloat(row.fieldIndex(columnIndexes.label)) - val features = row.getAs[Vector](columnIndexes.features) - val weight = columnIndexes.weight.map(v => row.getFloat(row.fieldIndex(v))).getOrElse(1.0f) - val baseMargin = columnIndexes.baseMargin.map(v => - row.getFloat(row.fieldIndex(v))).getOrElse(Float.NaN) - val group = columnIndexes.group.map(v => - row.getFloat(row.fieldIndex(v))).getOrElse(-1.0f) + val label = row.getFloat(columnIndices.labelId) + val features = row.getAs[Vector](columnIndices.featureId.get) + val weight = columnIndices.weightId.map(row.getFloat).getOrElse(1.0f) + val baseMargin = columnIndices.marginId.map(row.getFloat).getOrElse(Float.NaN) + val group = columnIndices.groupId.map(row.getFloat).getOrElse(-1.0f) // TODO support sparse vector. // TODO support array val values = features.toArray.map(_.toFloat) - val isValidation = columnIndexes.valiation.exists(v => - row.getBoolean(row.fieldIndex(v))) - - (isValidation, - XGBLabeledPoint(label, values.length, null, values, weight, group.toInt, baseMargin)) + XGBLabeledPoint(label, values.length, null, values, weight, group.toInt, baseMargin) } + } - - labeledPointRDD.mapPartitions { iter => - val datasets: ArrayBuffer[DMatrix] = ArrayBuffer.empty - val names: ArrayBuffer[String] = ArrayBuffer.empty - val validations: ArrayBuffer[XGBLabeledPoint] = ArrayBuffer.empty - - val trainIter = if (columnIndexes.valiation.isDefined) { - // Extract validations during build Train DMatrix - val dataIter = new Iterator[XGBLabeledPoint] { - private var tmp: Option[XGBLabeledPoint] = None - - override def hasNext: Boolean = { - if (tmp.isDefined) { - return true - } - while (iter.hasNext) { - val (isVal, labelPoint) = iter.next() - if (isVal) { - validations.append(labelPoint) - } else { - tmp = Some(labelPoint) - return true - } - } - false - } - - override def next(): XGBLabeledPoint = { - val xgbLabeledPoint = tmp.get - tmp = None - xgbLabeledPoint - } - } - dataIter - } else { - iter.map(_._2) + /** + * Convert the dataframe to RDD + * + * @param dataset + * @param columnIndices the column indices including weight/group/base margin ... 
+ * @return RDD[Watches] + */ + def toRdd(dataset: Dataset[_], columnIndices: ColumnIndices): RDD[Watches] = { + val trainRDD = toXGBLabeledPoint(dataset, columnIndices) + + getEvalDataset().map { eval => + val (evalDf, _) = preprocess(eval) + val evalRDD = toXGBLabeledPoint(evalDf, columnIndices) + trainRDD.zipPartitions(evalRDD) { (trainIter, evalIter) => + val trainDMatrix = new DMatrix(trainIter) + val evalDMatrix = new DMatrix(evalIter) + val watches = new Watches(Array(trainDMatrix, evalDMatrix), + Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None) + Iterator.single(watches) } - - datasets.append(new DMatrix(trainIter)) - names.append(Utils.TRAIN_NAME) - if (columnIndexes.valiation.isDefined) { - datasets.append(new DMatrix(validations.toIterator)) - names.append(Utils.VALIDATION_NAME) + }.getOrElse( + trainRDD.mapPartitions { iter => + // Handle weight/base margin + val watches = new Watches(Array(new DMatrix(iter)), Array(Utils.TRAIN_NAME), None) + Iterator.single(watches) } - - // TODO 1. support external memory 2. rework or remove Watches - val watches = new Watches(datasets.toArray, names.toArray, None) - Iterator.single(watches) - } + ) } protected def createModel(booster: Booster, summary: XGBoostTrainingSummary): M private def getRuntimeParameters(isLocal: Boolean): RuntimeParams = { - - val runOnGpu = false - + val runOnGpu = getDevice != "cpu" || getTreeMethod == "gpu_hist" RuntimeParams( getNumWorkers, getNumRound, @@ -361,9 +357,9 * @tparam M the exact model type, which must extend XGBoostModel */ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]]( - override val uid: String, - private val model: Booster, - private val trainingSummary: Option[XGBoostTrainingSummary]) extends Model[M] with MLWritable + override val uid: String, + private val model: Booster, + private val trainingSummary: Option[XGBoostTrainingSummary]) extends Model[M] with MLWritable with XGBoostParams[M] with SparkParams[M] { protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col" @@ -395,17 +391,19 @@ // Be careful about the order of columns var schema = dataset.schema - var hasLeafPredictionCol = false - if (isDefined(leafPredictionCol) && getLeafPredictionCol.nonEmpty) { - schema = schema.add(StructField(getLeafPredictionCol, ArrayType(FloatType))) - hasLeafPredictionCol = true + /** If the parameter is defined, add it to the schema and return true */ + def addToSchema(param: Param[String], colName: Option[String] = None): Boolean = { + if (isDefined(param) && $(param).nonEmpty) { + val name = colName.getOrElse($(param)) + schema = schema.add(StructField(name, ArrayType(FloatType))) + true + } else { + false + } } - var hasContribPredictionCol = false - if (isDefined(contribPredictionCol) && getContribPredictionCol.nonEmpty) { - schema = schema.add(StructField(getContribPredictionCol, ArrayType(FloatType))) - hasContribPredictionCol = true - } + val hasLeafPredictionCol = addToSchema(leafPredictionCol) + val hasContribPredictionCol = addToSchema(contribPredictionCol) var hasRawPredictionCol = false // For the classification case, the transformed col is probability, // If rawPredictionCol is set, we can use rawPrediction to calculate the probability var hasTransformedCol = false this match { case p: ClassificationParams[_] => // classification case - if (isDefined(p.rawPredictionCol) && p.getRawPredictionCol.nonEmpty) { - schema = 
schema.add( - StructField(p.getRawPredictionCol, ArrayType(FloatType))) - hasRawPredictionCol = true - } - if (isDefined(p.probabilityCol) && p.getProbabilityCol.nonEmpty) { - schema = schema.add( - StructField(TMP_TRANSFORMED_COL, ArrayType(FloatType))) - hasTransformedCol = true - } + hasRawPredictionCol = addToSchema(p.rawPredictionCol) + hasTransformedCol = addToSchema(p.probabilityCol, Some(TMP_TRANSFORMED_COL)) if (isDefined(predictionCol) && getPredictionCol.nonEmpty) { // Let's use transformed col to calculate the prediction @@ -435,11 +425,8 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]]( } case _ => // Rename TMP_TRANSFORMED_COL to prediction in the postTransform. - if (isDefined(predictionCol) && getPredictionCol.nonEmpty) { - schema = schema.add( - StructField(TMP_TRANSFORMED_COL, ArrayType(FloatType))) - hasTransformedCol = true - } + hasTransformedCol = addToSchema(predictionCol, Some(TMP_TRANSFORMED_COL)) + } // TODO configurable @@ -457,25 +444,29 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]]( // DMatrix used to prediction val dm = new DMatrix(features.map(_.asXGB)) - var tmpOut = batchRow.map(_.toSeq) + try { + var tmpOut = batchRow.map(_.toSeq) - val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map { - case (a, b) => a ++ Seq(b) - } + val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map { + case (a, b) => a ++ Seq(b) + } - if (hasLeafPredictionCol) { - tmpOut = zip(tmpOut, bBooster.value.predictLeaf(dm)) - } - if (hasContribPredictionCol) { - tmpOut = zip(tmpOut, bBooster.value.predictContrib(dm)) - } - if (hasRawPredictionCol) { - tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = true)) - } - if (hasTransformedCol) { - tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = false)) + if (hasLeafPredictionCol) { + tmpOut = zip(tmpOut, bBooster.value.predictLeaf(dm)) + } + if (hasContribPredictionCol) { + tmpOut = zip(tmpOut, bBooster.value.predictContrib(dm)) + } + if (hasRawPredictionCol) { + tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = true)) + } + if (hasTransformedCol) { + tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = false)) + } + tmpOut.map(Row.fromSeq) + } finally { + dm.delete() } - tmpOut.map(Row.fromSeq) } }(Encoders.row(schema)) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala index 00d805d626bb..e43fa0b3bbca 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala @@ -38,7 +38,7 @@ trait XGBoostPlugin extends Serializable { * @return RDD[Watches] */ def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]]( - estimator: XGBoostEstimator[T, M], - dataset: Dataset[_]): RDD[Watches] + estimator: XGBoostEstimator[T, M], + dataset: Dataset[_]): RDD[Watches] } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala index 776ade43ffb0..f976cad937e5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala +++ 
b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala @@ -94,15 +94,6 @@ trait HasFeaturesCols extends Params { } } -trait HasValidationIndicatorCol extends Params { - - final val validationIndicatorCol: Param[String] = new Param[String](this, - "validationIndicatorCol", "Name of the column that indicates whether each row is for " + - "training or for validation. False indicates training; true indicates validation.") - - final def getValidationIndicatorCol: String = $(validationIndicatorCol) -} - /** * A trait to hold non-xgboost parameters */ @@ -124,7 +115,7 @@ trait NonXGBoostParams extends Params { */ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFeaturesCol with HasLabelCol with HasBaseMarginCol with HasWeightCol with HasPredictionCol - with HasLeafPredictionCol with HasContribPredictionCol with HasValidationIndicatorCol + with HasLeafPredictionCol with HasContribPredictionCol with RabitParams with NonXGBoostParams with SchemaValidationTrait { final val numWorkers = new IntParam(this, "numWorkers", "Number of workers used to train xgboost", @@ -132,6 +123,12 @@ private[spark] trait SparkParams[T <: Params] extends HasFe final def getNumRound: Int = $(numRound) + final val forceRepartition = new BooleanParam(this, "forceRepartition", "If the number of partitions " + + "is equal to numWorkers, xgboost won't repartition the dataset. Set forceRepartition to " + + "true to force repartitioning.") + + final def getForceRepartition: Boolean = $(forceRepartition) + final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting", ParamValidators.gtEq(1)) @@ -139,6 +136,8 @@ private[spark] trait SparkParams[T <: Params] extends HasFe "Number of rounds of decreasing eval metric to tolerate before stopping training", ParamValidators.gtEq(0)) + final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds) + final val inferBatchSize = new IntParam(this, "inferBatchSize", "batch size in rows " + "to be grouped for inference", ParamValidators.gtEq(1)) @@ -146,19 +145,27 @@ private[spark] trait SparkParams[T <: Params] extends HasFe /** @group getParam */ final def getInferBatchSize: Int = $(inferBatchSize) - final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds) + /** + * the value treated as missing. 
default: Float.NaN + */ + final val missing = new FloatParam(this, "missing", "The value treated as missing") + + final def getMissing: Float = $(missing) setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10), - numEarlyStoppingRounds -> 0) + numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN, + featuresCols -> Array.empty) addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol, labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol, - validationIndicatorCol) + forceRepartition, missing, featuresCols) final def getNumWorkers: Int = $(numWorkers) def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T] + def setForceRepartition(value: Boolean): T = set(forceRepartition, value).asInstanceOf[T] + def setNumRound(value: Int): T = set(numRound, value).asInstanceOf[T] def setFeaturesCol(value: String): T = set(featuresCol, value).asInstanceOf[T] @@ -179,9 +186,6 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe def setInferBatchSize(value: Int): T = set(inferBatchSize, value).asInstanceOf[T] - def setValidationIndicatorCol(value: String): T = - set(validationIndicatorCol, value).asInstanceOf[T] - def setRabitTrackerTimeout(value: Int): T = set(rabitTrackerTimeout, value).asInstanceOf[T] def setRabitTrackerHostIp(value: String): T = set(rabitTrackerHostIp, value).asInstanceOf[T] @@ -210,9 +214,11 @@ private[spark] trait ClassificationParams[T <: Params] extends SparkParams[T] def setThresholds(value: Array[Double]): T = set(thresholds, value).asInstanceOf[T] + /** + * XGBoost doesn't use validateAndTransformSchema. + */ override def validateAndTransformSchema(schema: StructType, fitting: Boolean): StructType = { - var outputSchema = SparkUtils.appendColumn(schema, $(predictionCol), DoubleType) outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(rawPredictionCol)) outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(probabilityCol)) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index 57639aaebb4d..1cba5c672e9b 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -57,19 +57,19 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS // df = df.withColumn("base_margin", lit(20)) // .withColumn("weight", rand(1)) - var Array(trainDf, validationDf) = df.randomSplit(Array(0.8, 0.2), seed = 1) - - trainDf = trainDf.withColumn("validation", lit(false)) - validationDf = validationDf.withColumn("validationDf", lit(true)) - - df = trainDf.union(validationDf) - - // Assemble the feature columns into a single vector column + // Assemble the feature columns into a single vector column val assembler = new VectorAssembler() .setInputCols(features) .setOutputCol("features") val dataset = assembler.transform(df) + var Array(trainDf, validationDf) = dataset.randomSplit(Array(0.8, 0.2), seed = 1) + +// trainDf = trainDf.withColumn("validation", lit(false)) +// validationDf = validationDf.withColumn("validationDf", lit(true)) + +// df = trainDf.union(validationDf) + // val arrayInput = df.select(array(features.map(col(_)): _*).as("features"), // 
col("label"), col("base_margin")) @@ -80,7 +80,8 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS // .setWeightCol("weight") // .setBaseMarginCol("base_margin") .setLabelCol(labelCol) - .setValidationIndicatorCol("validation") + .setEvalDataset(validationDf) +// .setValidationIndicatorCol("validation") // .setPredictionCol("") .setRawPredictionCol("") .setProbabilityCol("xxxx") @@ -93,7 +94,7 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS println(loadedEst.getNumRound) println(loadedEst.getMaxDepth) - val model = loadedEst.fit(dataset) + val model = est.fit(dataset) println("-----------------------") println(model.getNumRound) println(model.getMaxDepth)