diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/GpuColumnBatch.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/GpuColumnBatch.java
index 13ef14d76e14..0e5bbb009ca3 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/GpuColumnBatch.java
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/GpuColumnBatch.java
@@ -43,15 +43,15 @@ public void close() {
     }
   }
 
-  public Table slice(int index) {
+  public Table select(int index) {
     if (index < 0) {
       return null;
     }
-    return slice(Arrays.asList(index));
+    return select(Arrays.asList(index));
   }
 
-  /** Slice the columns indicated by indices into a Table */
-  public Table slice(List<Integer> indices) {
+  /** Select the columns indicated by indices into a Table */
+  public Table select(List<Integer> indices) {
     if (indices == null || indices.size() == 0) {
       return null;
     }
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala
index 68313d436385..7eff5794dc81 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala
@@ -17,16 +17,22 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.mutable.ArrayBuffer
-import scala.jdk.CollectionConverters.seqAsJavaListConverter
+import scala.jdk.CollectionConverters.{asScalaIteratorConverter, seqAsJavaListConverter}
 
 import ai.rapids.cudf.Table
-import com.nvidia.spark.rapids.ColumnarRdd
+import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.TaskContext
+import org.apache.spark.ml.functions.array_to_vector
 import org.apache.spark.ml.param.Param
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, Dataset}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch}
-import ml.dmlc.xgboost4j.scala.QuantileDMatrix
+import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix}
 import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
@@ -35,6 +41,8 @@ import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
  */
 class GpuXGBoostPlugin extends XGBoostPlugin {
 
+  private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")
+
   /**
    * Whether the plugin is enabled or not; if not enabled, fall back
    * to the regular CPU pipeline
@@ -115,10 +123,10 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
       val colBatchIter = iter.map { table =>
         withResource(new GpuColumnBatch(table, null)) { batch =>
           new CudfColumnBatch(
-            batch.slice(indices.featureIds.get.map(Integer.valueOf).asJava),
-            batch.slice(indices.labelId),
-            batch.slice(indices.weightId.getOrElse(-1)),
-            batch.slice(indices.marginId.getOrElse(-1)));
+            batch.select(indices.featureIds.get.map(Integer.valueOf).asJava),
+            batch.select(indices.labelId),
+            batch.select(indices.weightId.getOrElse(-1)),
+            batch.select(indices.marginId.getOrElse(-1)));
         }
       }
       new QuantileDMatrix(colBatchIter, missing, maxBin, nthread)
@@ -150,4 +158,124 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
     }
   }
 
+  override def transform[M <: XGBoostModel[M]](model: XGBoostModel[M],
+      dataset: Dataset[_]): DataFrame = {
+    val sc = dataset.sparkSession.sparkContext
+
+    val (transformedSchema, pred) = model.preprocess(dataset)
+    val bBooster = sc.broadcast(model.nativeBooster)
+    val bOriginalSchema = sc.broadcast(dataset.schema)
+
+    val featureIds = model.getFeaturesCols.distinct.map(dataset.schema.fieldIndex).toList
+    val isLocal = sc.isLocal
+    val missing = model.getMissing
+    val nThread = model.getNthread
+
+    val rdd = ColumnarRdd(dataset.asInstanceOf[DataFrame]).mapPartitions { tableIters =>
+      // The booster is visible to all Spark tasks in the same executor
+      val booster = bBooster.value
+      val originalSchema = bOriginalSchema.value
+
+      // UnsafeProjection is not serializable, so create it on the executor side
+      val toUnsafe = UnsafeProjection.create(originalSchema)
+
+      synchronized {
+        val device = booster.getAttr("device")
+        if (device == null || device.trim.isEmpty) {
+          booster.setAttr("device", "cuda")
+          val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
+          booster.setParam("device", s"cuda:$gpuId")
+          logger.info("GPU transform on GPU device: " + gpuId)
+        }
+      }
+
+      // Iterator over the output rows
+      new Iterator[Row] {
+        // Convert InternalRow to Row
+        private val converter: InternalRow => Row = CatalystTypeConverters
+          .createToScalaConverter(originalSchema)
+          .asInstanceOf[InternalRow => Row]
+
+        // GPU batches read in must be closed by the receiver
+        @transient var currentBatch: ColumnarBatch = null
+
+        // Iterator over the rows of the current ColumnarBatch
+        var iter: Iterator[Row] = null
+
+        TaskContext.get().addTaskCompletionListener[Unit](_ => {
+          closeCurrentBatch() // close the last ColumnarBatch
+        })
+
+        private def closeCurrentBatch(): Unit = {
+          if (currentBatch != null) {
+            currentBatch.close()
+            currentBatch = null
+          }
+        }
+
+        def loadNextBatch(): Unit = {
+          closeCurrentBatch()
+          if (tableIters.hasNext) {
+            val dataTypes = originalSchema.fields.map(x => x.dataType)
+            iter = withResource(tableIters.next()) { table =>
+              val gpuColumnBatch = new GpuColumnBatch(table, originalSchema)
+              // Create DMatrix
+              val featureTable = gpuColumnBatch.select(featureIds.map(Integer.valueOf).asJava)
+              if (featureTable == null) {
+                throw new RuntimeException("Failed to select the feature columns")
+              }
+              try {
+                val cudfColumnBatch = new CudfColumnBatch(featureTable, null, null, null)
+                val dm = new DMatrix(cudfColumnBatch, missing, nThread)
+                if (dm == null) {
+                  Iterator.empty
+                } else {
+                  try {
+                    currentBatch = new ColumnarBatch(
+                      GpuColumnVectorUtils.extractHostColumns(table, dataTypes),
+                      table.getRowCount().toInt)
+                    val rowIterator = currentBatch.rowIterator().asScala.map(toUnsafe)
+                      .map(converter(_))
+                    model.predictInternal(booster, dm, pred, rowIterator).toIterator
+                  } finally {
+                    dm.delete()
+                  }
+                }
+              } finally {
+                featureTable.close()
+              }
+            }
+          } else {
+            iter = null
+          }
+        }
+
+        override def hasNext: Boolean = {
+          val itHasNext = iter != null && iter.hasNext
+          if (!itHasNext) { // no rows left in the current ColumnarBatch, load the next one
+            loadNextBatch()
+            iter != null && iter.hasNext
+          } else {
+            itHasNext
+          }
+        }
+
+        override def next(): Row = {
+          if (iter == null || !iter.hasNext) {
+            loadNextBatch()
+          }
+          if (iter == null) {
+            throw new NoSuchElementException()
+          }
+          iter.next()
+        }
+      }
+    }
+    bBooster.unpersist(false)
+    bOriginalSchema.unpersist(false)
+
+    val output = dataset.sparkSession.createDataFrame(rdd, transformedSchema)
+    model.postTransform(output, pred).toDF()
+  }
 }
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/resources/binary.test.parquet b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/binary.test.parquet
new file mode 100644
index 000000000000..5897b6fadb2b
Binary files /dev/null and b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/binary.test.parquet differ
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/resources/binary.train.parquet b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/binary.train.parquet
new file mode 100644
index 000000000000..780efdc13d36
Binary files /dev/null and b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/binary.train.parquet differ
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/resources/multiclass.test.parquet b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/multiclass.test.parquet
new file mode 100644
index 000000000000..b8347280f993
Binary files /dev/null and b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/multiclass.test.parquet differ
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/resources/multiclass.train.parquet b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/multiclass.train.parquet
new file mode 100644
index 000000000000..066f31b0ffa3
Binary files /dev/null and b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/multiclass.train.parquet differ
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/resources/regression.test.parquet b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/regression.test.parquet
new file mode 100644
index 000000000000..64036d134387
Binary files /dev/null and b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/regression.test.parquet differ
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/resources/regression.train.parquet b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/regression.train.parquet
new file mode 100644
index 000000000000..ecad3d4fd0ca
Binary files /dev/null and b/jvm-packages/xgboost4j-spark-gpu/src/test/resources/regression.train.parquet differ
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuTestSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuTestSuite.scala
index 9dbe8ee5d935..2f5fea3eec36 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuTestSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuTestSuite.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2021-2023 by Contributors
+ Copyright (c) 2021-2024 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala
index dcb3321d96cc..f8561fe88c5c 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala
@@ -1,14 +1,33 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
 package ml.dmlc.xgboost4j.scala.spark
 
+import java.io.File
+
 import scala.collection.mutable.ArrayBuffer
 
+import ai.rapids.cudf.{CSVOptions, DType, Schema, Table}
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.{FloatType, StructField, StructType}
 
 import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite
 
 class GpuXGBoostPluginSuite extends GpuTestSuite {
-
   test("isEnabled") {
     def checkIsEnabled(spark: SparkSession, expected: Boolean): Unit = {
       import spark.implicits._
@@ -37,7 +56,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
       (2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
       (3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
       (4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
-      (5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f),
+      (5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
     ).toDF("c1", "c2", "weight", "margin", "label", "other")
 
     val classifier = new XGBoostClassifier()
@@ -64,7 +83,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
       (2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
       (3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
       (4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
-      (5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f),
+      (5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f)
     ).toDF("c1", "c2", "weight", "margin", "label", "other")
       .repartition(5)
 
@@ -114,7 +133,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
       (2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
       (3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
       (4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
-      (5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f),
+      (5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
     ).toDF("c1", "c2", "weight", "margin", "label", "other")
 
     val features = Array("c1", "c2")
@@ -168,7 +187,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
 
     val train = Seq(
       (1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
-      (2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
+      (2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f)
     ).toDF("c1", "c2", "weight", "margin", "label", "other")
 
     // dataPoint -> (missing, rowNum, nonMissing)
@@ -179,7 +198,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
       (2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
       (3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
       (4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
-      (5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f),
+      (5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f)
    ).toDF("c1", "c2", "weight", "margin", "label", "other")
 
     val features = Array("c1", "c2")
@@ -226,4 +245,19 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
       }
     }
   }
+
+
+  test("XGBoost-Spark should match xgboost4j") {
+    withGpuSparkSession() { spark =>
+
+      val cols = Array("c0", "c1", "c2", "c3", "c4", "c5")
+      val label = "label"
+
+      val table = Table.readParquet(new File(getResourcePath("/binary.train.parquet")))
+      val df = spark.read.parquet(getResourcePath("/binary.train.parquet"))
+
+
+      df.show()
+    }
+  }
 }
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala
index 215fc81f3303..a8ba1c1b225a 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala
@@ -16,7 +16,6 @@ package ml.dmlc.xgboost4j.scala.spark
 
-import org.apache.spark.sql.functions.lit
 import org.scalatest.funsuite.AnyFunSuite
 
 import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite
@@ -41,53 +40,55 @@ class XXXXXSuite extends AnyFunSuite with GpuTestSuite {
 
     var Array(trainDf, validationDf) = df.randomSplit(Array(0.8, 0.2), seed = 1)
 
-//    trainDf = trainDf.withColumn("validation", lit(false))
-//    validationDf = validationDf.withColumn("validationDf", lit(true))
+    //    trainDf = trainDf.withColumn("validation", lit(false))
+    //    validationDf = validationDf.withColumn("validationDf", lit(true))
 
-//    df = trainDf.union(validationDf)
-//
-//    // Assemble the feature columns into a single vector column
-//    val assembler = new VectorAssembler()
-//      .setInputCols(features)
-//      .setOutputCol("features")
-//    val dataset = assembler.transform(df)
+    //    df = trainDf.union(validationDf)
+    //
+    //    // Assemble the feature columns into a single vector column
+    //    val assembler = new VectorAssembler()
+    //      .setInputCols(features)
+    //      .setOutputCol("features")
+    //    val dataset = assembler.transform(df)
 
 //    val arrayInput = df.select(array(features.map(col(_)): _*).as("features"),
 //      col("label"), col("base_margin"))
 
     val est = new XGBoostClassifier()
       .setNumWorkers(1)
-      .setNumRound(2)
-      .setMaxDepth(3)
+      .setNumRound(100)
+      //      .setMaxDepth(3)
 //      .setWeightCol("weight")
 //      .setBaseMarginCol("base_margin")
       .setFeaturesCol(features)
       .setLabelCol(labelCol)
+      .setLeafPredictionCol("leaf")
+      .setContribPredictionCol("contrib")
       .setDevice("cuda")
-      .setEvalDataset(validationDf)
-//      .setValidationIndicatorCol("validation")
-      // .setPredictionCol("")
-      .setRawPredictionCol("")
-      .setProbabilityCol("xxxx")
+      //      .setEvalDataset(validationDf)
+      //      .setValidationIndicatorCol("validation")
+      //      .setPredictionCol("")
+      //      .setRawPredictionCol("")
+      //      .setProbabilityCol("xxxx")
 //      .setContribPredictionCol("contrb")
 //      .setLeafPredictionCol("leaf")
 
 //    val est = new XGBoostClassifier().setLabelCol(labelCol)
 
 //    est.fit(arrayInput)
-    est.write.overwrite().save("/tmp/abcdef")
-    val loadedEst = XGBoostClassifier.load("/tmp/abcdef")
-    println(loadedEst.getNumRound)
-    println(loadedEst.getMaxDepth)
+    //    est.write.overwrite().save("/tmp/abcdef")
+    //    val loadedEst = XGBoostClassifier.load("/tmp/abcdef")
+    //    println(loadedEst.getNumRound)
+    //    println(loadedEst.getMaxDepth)
 
     val model = est.fit(trainDf)
-    println("-----------------------")
-    println(model.getNumRound)
-    println(model.getMaxDepth)
-
-//    model.write.overwrite().save("/tmp/model/")
-//    val loadedModel = XGBoostClassificationModel.load("/tmp/model")
-//    println(loadedModel.getNumRound)
-//    println(loadedModel.getMaxDepth)
-//    model.transform(df).drop(features: _*).show(150, false)
+
+    val out = model.transform(df)
+    out.printSchema()
+    out.show(150, false)
+    //    model.write.overwrite().save("/tmp/model/")
+    //    val loadedModel = XGBoostClassificationModel.load("/tmp/model")
+    //    println(loadedModel.getNumRound)
+    //    println(loadedModel.getMaxDepth)
+    //    model.transform(df).drop(features: _*).show(150, false)
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index 2ccae01e5b06..cb734d32a8f6 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
 import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
 import org.json4s.DefaultFormats
 
@@ -128,23 +128,24 @@ class XGBoostClassificationModel(
 
   def this(uid: String) = this(uid, 0, null)
 
-  override def postTransform(dataset: Dataset[_]): Dataset[_] = {
-    var output = dataset
+  override protected[spark] def postTransform(dataset: Dataset[_],
+      pred: PredictedColumns): Dataset[_] = {
+    var output = super.postTransform(dataset, pred)
     // Always use probability col to get the prediction
-    if (isDefinedNonEmpty(predictionCol)) {
+    if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
       val predCol = udf { probability: mutable.WrappedArray[Float] =>
         probability2prediction(Vectors.dense(probability.map(_.toDouble).toArray))
       }
       output = output.withColumn(getPredictionCol, predCol(col(TMP_TRANSFORMED_COL)))
     }
 
-    if (isDefinedNonEmpty(probabilityCol)) {
+    if (isDefinedNonEmpty(probabilityCol) && pred.predTmp) {
       output = output.withColumn(TMP_TRANSFORMED_COL,
         array_to_vector(output.col(TMP_TRANSFORMED_COL)))
         .withColumnRenamed(TMP_TRANSFORMED_COL, getProbabilityCol)
     }
 
-    if (isDefinedNonEmpty(rawPredictionCol)) {
+    if (pred.predRaw) {
       output = output.withColumn(getRawPredictionCol,
         array_to_vector(output.col(getRawPredictionCol)))
     }
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
index 2a57a65c45a5..810be139f2c7 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -67,13 +67,7 @@ private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoost
   }
 }
 
-private[spark] trait XGBoostEstimator[
-  Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
-  with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
-  with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable {
-
-  protected val logger = LogFactory.getLog("XGBoostSpark")
-
+private[spark] trait PluginMixin {
   // Find the XGBoostPlugin by ServiceLoader
   private val plugin: Option[XGBoostPlugin] = {
     val classLoader = Option(Thread.currentThread().getContextClassLoader)
@@ -92,11 +86,20 @@ private[spark] trait XGBoostEstimator[
   }
 
   /** Visible for testing */
-  private[spark] def getPlugin: Option[XGBoostPlugin] = plugin
+  protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin
 
-  private def isPluginEnabled(dataset: Dataset[_]): Boolean = {
+  protected def isPluginEnabled(dataset: Dataset[_]): Boolean = {
     plugin.map(_.isEnabled(dataset)).getOrElse(false)
   }
+}
+
+private[spark] trait XGBoostEstimator[
+  Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
+  with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
+  with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable
+  with PluginMixin {
+
+  protected val logger = LogFactory.getLog("XGBoostSpark")
 
   /**
    * Pre-convert input double data to floats to align with XGBoost's internal float-based
@@ -383,7 +386,7 @@ private[spark] trait XGBoostEstimator[
     validate(dataset)
 
     val rdd = if (isPluginEnabled(dataset)) {
-      plugin.get.buildRddWatches(this, dataset)
+      getPlugin.get.buildRddWatches(this, dataset)
     } else {
       val (input, columnIndexes) = preprocess(dataset)
       toRdd(input, columnIndexes)
@@ -407,6 +410,13 @@ private[spark] trait XGBoostEstimator[
   }
 }
 
+/** Indicates which columns are to be predicted */
+private[spark] case class PredictedColumns(
+    predLeaf: Boolean,
+    predContrib: Boolean,
+    predRaw: Boolean,
+    predTmp: Boolean)
+
 /**
  * XGBoost base model
  *
@@ -416,7 +426,7 @@ private[spark] trait XGBoostEstimator[
  *
  * @tparam M the exact model, which must extend XGBoostModel
  */
 private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with MLWritable
-  with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] {
+  with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] with PluginMixin {
 
   protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col"
@@ -436,12 +446,27 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
     validateAndTransformSchema(schema, false)
   }
 
-  def postTransform(dataset: Dataset[_]): Dataset[_] = dataset
-
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  protected[spark] def postTransform(dataset: Dataset[_], pred: PredictedColumns): Dataset[_] = {
+    var output = dataset
+    // Convert leaf/contrib predictions from array to vector
+    if (pred.predLeaf) {
+      output = output.withColumn(getLeafPredictionCol,
+        array_to_vector(output.col(getLeafPredictionCol)))
+    }
 
-    val spark = dataset.sparkSession
+    if (pred.predContrib) {
+      output = output.withColumn(getContribPredictionCol,
+        array_to_vector(output.col(getContribPredictionCol)))
+    }
+    output
+  }
 
+  /**
+   * Preprocess the schema before transforming.
+   *
+   * @return the transformed schema and the columns to be predicted
+   */
+  private[spark] def preprocess(dataset: Dataset[_]): (StructType, PredictedColumns) = {
     // Be careful about the order of columns
     var schema = dataset.schema
@@ -456,68 +481,77 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
       }
     }
 
-    val hasLeafPredictionCol = addToSchema(leafPredictionCol)
-    val hasContribPredictionCol = addToSchema(contribPredictionCol)
+    val predLeaf = addToSchema(leafPredictionCol)
+    val predContrib = addToSchema(contribPredictionCol)
 
-    var hasRawPredictionCol = false
+    var predRaw = false
     // For the classification case, the transformed col is the probability,
     // while for others, it's the prediction value.
-    var hasTransformedCol = false
+    var predTmp = false
     this match {
       case p: XGBProbabilisticClassifierParams[_] => // classification case
-        hasRawPredictionCol = addToSchema(p.rawPredictionCol)
-        hasTransformedCol = addToSchema(p.probabilityCol, Some(TMP_TRANSFORMED_COL))
+        predRaw = addToSchema(p.rawPredictionCol)
+        predTmp = addToSchema(p.probabilityCol, Some(TMP_TRANSFORMED_COL))
 
         if (isDefinedNonEmpty(predictionCol)) {
           // Let's use the transformed col to calculate the prediction
-          if (!hasTransformedCol) {
+          if (!predTmp) {
             // Add the transformed col for prediction
             schema = schema.add(
               StructField(TMP_TRANSFORMED_COL, ArrayType(FloatType)))
-            hasTransformedCol = true
+            predTmp = true
           }
         }
       case _ =>
         // Rename TMP_TRANSFORMED_COL to prediction in the postTransform.
-        hasTransformedCol = addToSchema(predictionCol, Some(TMP_TRANSFORMED_COL))
+        predTmp = addToSchema(predictionCol, Some(TMP_TRANSFORMED_COL))
+    }
+    (schema, PredictedColumns(predLeaf, predContrib, predRaw, predTmp))
+  }
+
+  /** Run the configured predictions on the given DMatrix and append the results to each row. */
+  private[spark] def predictInternal(booster: Booster, dm: DMatrix, pred: PredictedColumns,
+      batchRow: Iterator[Row]): Seq[Row] = {
+    var tmpOut = batchRow.toSeq.map(_.toSeq)
+    val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map {
+      case (a, b) => a ++ Seq(b)
     }
+    if (pred.predLeaf) {
+      tmpOut = zip(tmpOut, booster.predictLeaf(dm))
+    }
+    if (pred.predContrib) {
+      tmpOut = zip(tmpOut, booster.predictContrib(dm))
+    }
+    if (pred.predRaw) {
+      tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = true))
+    }
+    if (pred.predTmp) {
+      tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = false))
+    }
+    tmpOut.map(Row.fromSeq)
+  }
+
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    if (getPlugin.isDefined) {
+      return getPlugin.get.transform(this, dataset)
+    }
+
+    val (schema, pred) = preprocess(dataset)
+
+    val bBooster = dataset.sparkSession.sparkContext.broadcast(nativeBooster)
     // TODO configurable
     val inferBatchSize = 32 << 10
     // Broadcast the booster to each executor.
-    val bBooster = spark.sparkContext.broadcast(nativeBooster)
     val featureName = getFeaturesCol
 
-    var output = dataset.toDF().mapPartitions { rowIter =>
-
+    val output = dataset.toDF().mapPartitions { rowIter =>
       rowIter.grouped(inferBatchSize).flatMap { batchRow =>
         val features = batchRow.iterator.map(row => row.getAs[Vector](
           row.fieldIndex(featureName)))
-
         // DMatrix used for prediction
         val dm = new DMatrix(features.map(_.asXGB))
-
         try {
-          var tmpOut = batchRow.map(_.toSeq)
-
-          val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map {
-            case (a, b) => a ++ Seq(b)
-          }
-
-          if (hasLeafPredictionCol) {
-            tmpOut = zip(tmpOut, bBooster.value.predictLeaf(dm))
-          }
-          if (hasContribPredictionCol) {
-            tmpOut = zip(tmpOut, bBooster.value.predictContrib(dm))
-          }
-          if (hasRawPredictionCol) {
-            tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = true))
-          }
-          if (hasTransformedCol) {
-            tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = false))
-          }
-          tmpOut.map(Row.fromSeq)
+          predictInternal(bBooster.value, dm, pred, batchRow.toIterator)
         } finally {
           dm.delete()
         }
@@ -525,19 +559,7 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
     }(Encoders.row(schema))
     bBooster.unpersist(blocking = false)
-
-    // Convert leaf/contrib to the vector from array
-    if (hasLeafPredictionCol) {
-      output = output.withColumn(getLeafPredictionCol,
-        array_to_vector(output.col(getLeafPredictionCol)))
-    }
-
-    if (hasContribPredictionCol) {
-      output = output.withColumn(getContribPredictionCol,
-        array_to_vector(output.col(getContribPredictionCol)))
-    }
-
-    postTransform(output).toDF()
+    postTransform(output, pred).toDF()
   }
 
   override def write: MLWriter = new XGBoostModelWriter[XGBoostModel[_]](this)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala
index e43fa0b3bbca..3e18b6439988 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.Serializable
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 trait XGBoostPlugin extends Serializable {
   /**
@@ -41,4 +41,7 @@ trait XGBoostPlugin extends Serializable {
       estimator: XGBoostEstimator[T, M],
       dataset: Dataset[_]): RDD[Watches]
 
+
+  def transform[M <: XGBoostModel[M]](model: XGBoostModel[M], dataset: Dataset[_]): DataFrame
+
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index df4bf081f4f7..e77e23da4ea2 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -108,7 +108,7 @@ trait PerTest extends BeforeAndAfterEach {
     (0.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
     (1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
     (0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
-    (1.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7)),
+    (1.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
   ))).toDF("label", "margin", "weight", "features")
 
   def smallMultiClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
@@ -117,7 +117,7 @@ trait PerTest extends BeforeAndAfterEach {
     (2.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
     (1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
     (0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
-    (2.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7)),
+    (2.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
   ))).toDF("label", "margin", "weight", "features")
 
@@ -127,7 +127,7 @@ trait PerTest extends BeforeAndAfterEach {
     (2.0, 1, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
     (1.0, 0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
     (0.0, 2, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
-    (2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7)),
+    (2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
   ))).toDF("label", "group", "margin", "weight", "features")
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
index c2f583d1250a..484673922d74 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
@@ -22,8 +22,8 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.ml.linalg.Vectors
-import org.json4s.jackson.parseJson
 import org.json4s.{DefaultFormats, Formats}
+import org.json4s.jackson.parseJson
 import org.scalatest.funsuite.AnyFunSuite
 
 import ml.dmlc.xgboost4j.scala.DMatrix
@@ -145,7 +145,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
     val dataset = ss.createDataFrame(sc.parallelize(Seq(
       (1.0, 0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0), "a"),
       (0.0, 2, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0), "b"),
-      (2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7), "c"),
+      (2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7), "c")
     ))).toDF("label", "group", "margin", "weight", "features", "other")
 
     val classifier = new XGBoostClassifier()
@@ -187,7 +187,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       (1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
       (2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
       (3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
-      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c"),
+      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c")
     ))).toDF("label", "group", "margin", "weight", "features", "other")
 
     val classifier = new XGBoostClassifier()
@@ -224,7 +224,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       (1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
      (2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
       (3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
-      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c"),
+      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c")
     ))).toDF("label", "group", "margin", "weight", "features", "other")
 
     val classifier = new XGBoostClassifier()
@@ -260,7 +260,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       (1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
       (2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
       (3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
-      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c"),
+      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c")
     ))).toDF("label", "group", "margin", "weight", "features", "other")
 
     val classifier = new XGBoostClassifier()
@@ -300,7 +300,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       (1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
       (2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
       (3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
-      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c"),
+      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c")
     ))).toDF("label", "group", "margin", "weight", "features", "other")
 
     val classifier = new XGBoostClassifier()
@@ -357,7 +357,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       (11.0, 0, 0.15, 11.0, Vectors.dense(trainData(0)), "a"),
       (12.0, 12, -0.15, 10.0, Vectors.dense(trainData(1)).toSparse, "b"),
       (13.0, 12, -0.15, 10.0, Vectors.dense(trainData(2)), "b"),
-      (14.0, 12, -0.14, -12.1, Vectors.dense(trainData(3)), "c"),
+      (14.0, 12, -0.14, -12.1, Vectors.dense(trainData(3)), "c")
     ))).toDF("label", "group", "margin", "weight", "features", "other")
     val evalData = Array(
       Array(1.0, 2.0, 3.0, 4.0, 5.0),
@@ -369,7 +369,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       (1.0, 0, 0.5, 1.0, Vectors.dense(evalData(0)), "a"),
       (2.0, 2, -0.5, 0.0, Vectors.dense(evalData(1)).toSparse, "b"),
       (3.0, 2, -0.5, 0.0, Vectors.dense(evalData(2)), "b"),
-      (4.0, 2, -0.4, -2.1, Vectors.dense(evalData(3)), "c"),
+      (4.0, 2, -0.4, -2.1, Vectors.dense(evalData(3)), "c")
     ))).toDF("label", "group", "margin", "weight", "features", "other")
 
     val classifier = new XGBoostClassifier()
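
Reviewer sketch: taken together, the changes above make `XGBoostModel.transform` delegate to the new `XGBoostPlugin.transform` hook whenever a plugin is discovered via ServiceLoader, so on a cluster with the spark-rapids plugin the prediction path runs through `GpuXGBoostPlugin` (cuDF columns fed straight into a `DMatrix`, host rows materialized only for the output), and falls back to the row-based CPU path otherwise. Below is a minimal, self-contained sketch of the user-facing call path this enables; the object name, column names, and toy data are illustrative only, not taken from this PR.

```scala
import org.apache.spark.sql.SparkSession

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

object GpuTransformSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    // Toy data: two float feature columns and a binary label.
    val df = Seq(
      (1.0f, 2.0f, 0.0f),
      (3.0f, 4.0f, 1.0f),
      (5.0f, 6.0f, 0.0f)
    ).toDF("c1", "c2", "label")

    val model = new XGBoostClassifier()
      .setNumWorkers(1)
      .setNumRound(10)
      .setFeaturesCol(Array("c1", "c2"))
      .setLabelCol("label")
      .setDevice("cuda")
      .fit(df)

    // transform() dispatches to GpuXGBoostPlugin when the plugin jar is on
    // the classpath and enabled, and to the CPU mapPartitions path otherwise.
    model.transform(df).show()
  }
}
```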