From 1a9dd485c15a960598c24c6a5bcd9588258ac718 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Thu, 20 Jun 2024 20:02:57 +0800
Subject: [PATCH] Support base margin and add more tests (#12)

---
 .../scala/spark/XGBoostEstimator.scala        |  31 +++-
 .../scala/spark/XGBoostClassifierSuite.scala  |  88 +++++++++++-
 .../scala/spark/XGBoostEstimatorSuite.scala   | 136 +++++++++++++++++-
 3 files changed, 241 insertions(+), 14 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
index c0d367c58b44..ebfadc4da63e 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.util.ServiceLoader
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters.iterableAsScalaIterableConverter
 
@@ -258,20 +259,38 @@ private[spark] abstract class XGBoostEstimator[
   private[spark] def toRdd(dataset: Dataset[_], columnIndices: ColumnIndices): RDD[Watches] = {
     val trainRDD = toXGBLabeledPoint(dataset, columnIndices)
 
+    // Transform the XGBLabeledPoint iterator to collect the base margins while building the DMatrix.
+    // TODO: support base margin for multi-class classification
+    // TODO: move this into JNI
+    def buildDMatrix(iter: Iterator[XGBLabeledPoint]) = {
+      if (columnIndices.marginId.isDefined) {
+        val trainMargins = new mutable.ArrayBuilder.ofFloat
+        val transformedIter = iter.map { labeledPoint =>
+          trainMargins += labeledPoint.baseMargin
+          labeledPoint
+        }
+        val dm = new DMatrix(transformedIter)
+        dm.setBaseMargin(trainMargins.result())
+        dm
+      } else {
+        new DMatrix(iter)
+      }
+    }
+
     getEvalDataset().map { eval =>
       val (evalDf, _) = preprocess(eval)
       val evalRDD = toXGBLabeledPoint(evalDf, columnIndices)
-      trainRDD.zipPartitions(evalRDD) { (trainIter, evalIter) =>
-        val trainDMatrix = new DMatrix(trainIter)
-        val evalDMatrix = new DMatrix(evalIter)
+      trainRDD.zipPartitions(evalRDD) { (left, right) =>
+        val trainDMatrix = buildDMatrix(left)
+        val evalDMatrix = buildDMatrix(right)
         val watches = new Watches(Array(trainDMatrix, evalDMatrix),
           Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
         Iterator.single(watches)
       }
     }.getOrElse(
       trainRDD.mapPartitions { iter =>
-        // Handle weight/base margin
-        val watches = new Watches(Array(new DMatrix(iter)), Array(Utils.TRAIN_NAME), None)
+        val dm = buildDMatrix(iter)
+        val watches = new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
         Iterator.single(watches)
       }
     )
@@ -371,7 +390,7 @@ private[spark] abstract class XGBoostEstimator[
     copyValues(createModel(booster, summary))
   }
 
-  override def copy(extra: ParamMap): Learner = defaultCopy(extra)
+  override def copy(extra: ParamMap): Learner = defaultCopy(extra).asInstanceOf[Learner]
 
   // Not used in XGBoost
   override def transformSchema(schema: StructType): StructType = {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 1cba5c672e9b..aabb2c57adf5 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -16,12 +16,88 @@ package ml.dmlc.xgboost4j.scala.spark
 
+import java.io.File
+
 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.ml.param.ParamMap
 import org.scalatest.funsuite.AnyFunSuite
 
+import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
+
 class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
 
+  test("params") {
+    val xgbParams: Map[String, Any] = Map(
+      "max_depth" -> 5,
+      "eta" -> 0.2,
+      "objective" -> "binary:logistic"
+    )
+    val classifier = new XGBoostClassifier(xgbParams)
+      .setFeaturesCol("abc")
+      .setMissing(0.2f)
+      .setAlpha(0.97)
+
+    assert(classifier.getMaxDepth === 5)
+    assert(classifier.getEta === 0.2)
+    assert(classifier.getObjective === "binary:logistic")
+    assert(classifier.getFeaturesCol === "abc")
+    assert(classifier.getMissing === 0.2f)
+    assert(classifier.getAlpha === 0.97)
+
+    classifier.setEta(0.66).setMaxDepth(7)
+    assert(classifier.getMaxDepth === 7)
+    assert(classifier.getEta === 0.66)
+  }
+
+  test("XGBoostClassifier copy") {
+    val classifier = new XGBoostClassifier().setNthread(2).setNumWorkers(10)
+    val classifierCopied = classifier.copy(ParamMap.empty)
+
+    assert(classifier.uid === classifierCopied.uid)
+    assert(classifier.getNthread === classifierCopied.getNthread)
+    assert(classifier.getNumWorkers === classifierCopied.getNumWorkers)
+  }
+
+  test("XGBoostClassificationModel copy") {
+    val model = new XGBoostClassificationModel("hello").setNthread(2).setNumWorkers(10)
+    val modelCopied = model.copy(ParamMap.empty)
+    assert(model.uid === modelCopied.uid)
+    assert(model.getNthread === modelCopied.getNthread)
+    assert(model.getNumWorkers === modelCopied.getNumWorkers)
+  }
+
+  test("read/write") {
+    val trainDf = smallBinaryClassificationVector
+    val xgbParams: Map[String, Any] = Map(
+      "max_depth" -> 5,
+      "eta" -> 0.2,
+      "objective" -> "binary:logistic"
+    )
+
+    def check(xgboostParams: XGBoostParams[_]): Unit = {
+      assert(xgboostParams.getMaxDepth === 5)
+      assert(xgboostParams.getEta === 0.2)
+      assert(xgboostParams.getObjective === "binary:logistic")
+    }
+
+    val classifierPath = new File(tempDir.toFile, "classifier").getPath
+    val classifier = new XGBoostClassifier(xgbParams)
+    check(classifier)
+
+    classifier.write.overwrite().save(classifierPath)
+    val loadedClassifier = XGBoostClassifier.load(classifierPath)
+    check(loadedClassifier)
+
+    val model = loadedClassifier.fit(trainDf)
+    check(model)
+
+    val modelPath = new File(tempDir.toFile, "model").getPath
+    model.write.overwrite().save(modelPath)
+    val modelLoaded = XGBoostClassificationModel.load(modelPath)
+    check(modelLoaded)
+  }
+
   test("pipeline") {
     val spark = ss
     var df = spark.read.parquet("/home/bobwang/data/iris/parquet")
@@ -57,7 +133,7 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
     // df = df.withColumn("base_margin", lit(20))
     //   .withColumn("weight", rand(1))
 
-    // Assemble the feature columns into a single vector column 
+    // Assemble the feature columns into a single vector column
     val assembler = new VectorAssembler()
       .setInputCols(features)
       .setOutputCol("features")
@@ -65,10 +141,10 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
 
     var Array(trainDf, validationDf) = dataset.randomSplit(Array(0.8, 0.2), seed = 1)
 
-//    trainDf = trainDf.withColumn("validation", lit(false))
-//    validationDf = validationDf.withColumn("validationDf", lit(true))
+    // trainDf = trainDf.withColumn("validation", lit(false))
+    // validationDf = validationDf.withColumn("validationDf", lit(true))
 
-//    df = trainDf.union(validationDf)
+    // df = trainDf.union(validationDf)
 
     // val arrayInput = df.select(array(features.map(col(_)): _*).as("features"),
     //   col("label"), col("base_margin"))
@@ -81,7 +157,7 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
       // .setBaseMarginCol("base_margin")
       .setLabelCol(labelCol)
       .setEvalDataset(validationDf)
-//      .setValidationIndicatorCol("validation")
+      // .setValidationIndicatorCol("validation")
       // .setPredictionCol("")
       .setRawPredictionCol("")
       .setProbabilityCol("xxxx")
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
index 5031b00d924d..280d17e6a287 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
@@ -16,6 +16,8 @@ package ml.dmlc.xgboost4j.scala.spark
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.ml.linalg.Vectors
 import org.scalatest.funsuite.AnyFunSuite
 
@@ -117,6 +119,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       .setLabelCol("label")
      .setFeaturesCol("features")
       .setWeightCol("weight")
+      .setNumWorkers(2)
 
     val (df, indices) = classifier.preprocess(dataset)
     val rdd = classifier.toXGBLabeledPoint(df, indices)
@@ -153,6 +156,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       .setLabelCol("label")
       .setFeaturesCol("features")
       .setWeightCol("weight")
+      .setNumWorkers(2)
 
     val (df, indices) = classifier.preprocess(dataset)
     val rdd = classifier.toXGBLabeledPoint(df, indices)
@@ -188,6 +192,8 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
       .setLabelCol("label")
       .setFeaturesCol("features")
       .setWeightCol("weight")
+      .setBaseMarginCol("margin")
+      .setNumWorkers(2)
       .setMissing(0.0f)
 
     val (df, indices) = classifier.preprocess(dataset)
@@ -196,9 +202,9 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
 
     assert(result.length == 2)
 
-    assert(result(0).label === 1.0f && result(0).baseMargin.isNaN &&
+    assert(result(0).label === 1.0f && result(0).baseMargin === 0.5f &&
       result(0).weight === 1.0f && result(0).values === data(0).map(_.toFloat))
-    assert(result(1).label === 3.0f && result(1).baseMargin.isNaN &&
+    assert(result(1).label === 3.0f && result(1).baseMargin === -0.5f &&
       result(1).weight == 0.0f)
     assert(result(1).values(0) === 12.0f)
@@ -208,4 +214,130 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
     assert(result(1).values(4) === 15.0f)
   }
 
+  test("test to RDD watches") {
+    val data = Array(
+      Array(1.0, 2.0, 3.0, 4.0, 5.0),
+      Array(0.0, 0.0, 0.0, 0.0, 2.0),
+      Array(12.0, 13.0, 14.0, 14.0, 15.0),
+      Array(20.5, 21.2, 0.0, 0.0, 2.0)
+    )
+    val dataset = ss.createDataFrame(sc.parallelize(Seq(
+      (1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
+      (2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
+      (3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
+      (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c"),
+    ))).toDF("label", "group", "margin", "weight", "features", "other")
+
+    val classifier = new XGBoostClassifier()
+      .setLabelCol("label")
+      .setFeaturesCol("features")
+      .setWeightCol("weight")
+      .setBaseMarginCol("margin")
+      .setNumWorkers(2)
+
+    val (df, indices) = classifier.preprocess(dataset)
+    val rdd = classifier.toRdd(df, indices)
+    val result = rdd.mapPartitions { iter =>
+      if (iter.hasNext) {
+        val watches = iter.next()
+        val size = watches.size
+        val rowNum = watches.datasets(0).rowNum
+        val labels = watches.datasets(0).getLabel
+        val weight = watches.datasets(0).getWeight
+        val margins = watches.datasets(0).getBaseMargin
+        watches.delete()
+        Iterator.single((size, rowNum, labels, weight, margins))
+      } else {
+        Iterator.empty
+      }
+    }.collect()
+
+    val labels: ArrayBuffer[Float] = ArrayBuffer.empty
+    val weight: ArrayBuffer[Float] = ArrayBuffer.empty
+    val margins: ArrayBuffer[Float] = ArrayBuffer.empty
+
+    var totalRows = 0L
+    for (row <- result) {
+      assert(row._1 === 1)
+      totalRows = totalRows + row._2
+      labels.append(row._3: _*)
+      weight.append(row._4: _*)
+      margins.append(row._5: _*)
+    }
+    assert(totalRows === 4)
+    assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
+    assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
+    assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
+
+  }
+
+  test("test to RDD watches with eval") {
+    val trainData = Array(
+      Array(-1.0, -2.0, -3.0, -4.0, -5.0),
+      Array(2.0, 2.0, 2.0, 3.0, -2.0),
+      Array(-12.0, -13.0, -14.0, -14.0, -15.0),
+      Array(-20.5, -21.2, 0.0, 0.0, 2.0)
+    )
+    val trainDataset = ss.createDataFrame(sc.parallelize(Seq(
+      (11.0, 0, 0.15, 11.0, Vectors.dense(trainData(0)), "a"),
+      (12.0, 12, -0.15, 10.0, Vectors.dense(trainData(1)).toSparse, "b"),
+      (13.0, 12, -0.15, 10.0, Vectors.dense(trainData(2)), "b"),
+      (14.0, 12, -0.14, -12.1, Vectors.dense(trainData(3)), "c"),
+    ))).toDF("label", "group", "margin", "weight", "features", "other")
+    val evalData = Array(
+      Array(1.0, 2.0, 3.0, 4.0, 5.0),
+      Array(0.0, 0.0, 0.0, 0.0, 2.0),
+      Array(12.0, 13.0, 14.0, 14.0, 15.0),
+      Array(20.5, 21.2, 0.0, 0.0, 2.0)
+    )
+    val evalDataset = ss.createDataFrame(sc.parallelize(Seq(
+      (1.0, 0, 0.5, 1.0, Vectors.dense(evalData(0)), "a"),
+      (2.0, 2, -0.5, 0.0, Vectors.dense(evalData(1)).toSparse, "b"),
+      (3.0, 2, -0.5, 0.0, Vectors.dense(evalData(2)), "b"),
+      (4.0, 2, -0.4, -2.1, Vectors.dense(evalData(3)), "c"),
+    ))).toDF("label", "group", "margin", "weight", "features", "other")
+
+    val classifier = new XGBoostClassifier()
+      .setLabelCol("label")
+      .setFeaturesCol("features")
+      .setWeightCol("weight")
+      .setBaseMarginCol("margin")
+      .setEvalDataset(evalDataset)
+      .setNumWorkers(2)
+
+    val (df, indices) = classifier.preprocess(trainDataset)
+    val rdd = classifier.toRdd(df, indices)
+    val result = rdd.mapPartitions { iter =>
+      if (iter.hasNext) {
+        val watches = iter.next()
+        val size = watches.size
+        val rowNum = watches.datasets(1).rowNum
+        val labels = watches.datasets(1).getLabel
+        val weight = watches.datasets(1).getWeight
+        val margins = watches.datasets(1).getBaseMargin
+        watches.delete()
+        Iterator.single((size, rowNum, labels, weight, margins))
+      } else {
+        Iterator.empty
+      }
+    }.collect()
+
+    val labels: ArrayBuffer[Float] = ArrayBuffer.empty
+    val weight: ArrayBuffer[Float] = ArrayBuffer.empty
+    val margins: ArrayBuffer[Float] = ArrayBuffer.empty
+
+    var totalRows = 0L
+    for (row <- result) {
+      assert(row._1 === 2)
+      totalRows = totalRows + row._2
+      labels.append(row._3: _*)
+      weight.append(row._4: _*)
+      margins.append(row._5: _*)
+    }
+    assert(totalRows === 4)
+    assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
+    assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
+    assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
+  }
+
 }
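
For context, a minimal usage sketch of the base-margin support this patch adds. The SparkSession value `spark`, the toy DataFrame, and the column names are illustrative assumptions, not part of the patch; the setters are the ones exercised in the tests above.

// Sketch only: assumes a live SparkSession bound to `spark`.
import org.apache.spark.ml.linalg.Vectors
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val df = spark.createDataFrame(Seq(
  (1.0, 0.5, Vectors.dense(1.0, 2.0, 3.0)),
  (0.0, -0.5, Vectors.dense(4.0, 5.0, 6.0))
)).toDF("label", "margin", "features")

// With a base-margin column configured, toRdd collects each row's margin
// while building the per-partition DMatrix and applies it via setBaseMargin.
val model = new XGBoostClassifier(Map("objective" -> "binary:logistic"))
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setBaseMarginCol("margin")
  .setNumWorkers(1)
  .fit(df)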